Coverage for slidge/core/dispatcher/util.py: 91%
113 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1import logging
2from functools import wraps
3from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar
5from slixmpp import JID, Iq, Message, Presence
6from slixmpp.exceptions import XMPPError
7from slixmpp.xmlstream import StanzaBase
9from ...util.types import Recipient, RecipientType
10from ..session import BaseSession
12if TYPE_CHECKING:
13 from slidge import BaseGateway
14 from slidge.group import LegacyMUC
class Ignore(BaseException):
    """Raised by dispatcher helpers to silently drop an incoming stanza.

    It derives from ``BaseException`` rather than ``Exception`` so that
    generic ``except Exception`` clauses do not intercept it.
    """
class DispatcherMixin:
    """Helpers shared by the stanza dispatchers.

    Resolves the user session behind an incoming stanza, looks up the MUC a
    stanza is addressed to, and translates XMPP message/thread IDs into the
    legacy network's IDs.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session of the user who sent *stanza*.

        :param stanza: incoming stanza; its ``from`` JID selects the session.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: if true, await ``session.wait_for_ready(timeout)``.
        :param logged: if true, call ``session.raise_if_not_logged()``.
        :raises Ignore: when the sender's domain equals the gateway's bare JID
            (an echo), when the stanza is a ``chat`` message addressed to the
            component's bare JID, or when ``_ignore`` flags the message.
        :raises XMPPError: ``registration-required`` if the sender has no
            session.
        """
        xmpp = self.xmpp
        # Stanzas whose sender domain is our own JID are echoes.
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session registered for *jid*.

        :param jid: JID of the user.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: if true, await ``session.wait_for_ready(timeout)``.
        :param logged: if true, call ``session.raise_if_not_logged()`` first.
        :raises XMPPError: ``registration-required`` if no session exists.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
        """Return the MUC that *iq* (any stanza type) is addressed to.

        :raises XMPPError: ``bad-request`` if the stanza is addressed to the
            component's bare JID rather than to a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc

    def _xmpp_msg_id_to_legacy(
        self, session: "BaseSession", xmpp_id: str, recipient: Recipient
    ):
        """Translate an XMPP message ID into a legacy message ID.

        The ID map store is consulted first. If it has no entry, the result of
        ``session.xmpp_to_legacy_msg_id(xmpp_id)`` is returned instead.

        :raises XMPPError: re-raised unchanged from ``xmpp_to_legacy_msg_id``;
            any other exception it raises is logged and converted to
            ``internal-server-error`` (chained to the original).
        """
        with self.xmpp.store.session() as orm:
            # NOTE(review): the group flag is hard-coded to False here, while
            # _xmpp_to_legacy_thread passes recipient.is_group to the store.
            # Confirm this is intended for group recipients.
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, False
            )
        if sent is not None:
            return self.xmpp.LEGACY_MSG_ID_TYPE(sent)

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple["BaseSession", Recipient, int | str | None]:
        """Resolve the session, recipient and legacy thread ID for *msg*.

        The thread element is None when *msg* carries no ``<thread>``.
        """
        session = await self._get_session(msg)
        e: Recipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return session, e, legacy_thread

    async def _xmpp_to_legacy_thread(
        self, session: "BaseSession", msg: Message, recipient: RecipientType
    ):
        """Return the legacy thread ID for the ``<thread>`` of *msg*, or None.

        - No thread element: return None.
        - If ``session.MESSAGE_IDS_ARE_THREAD_IDS``: translate the thread like
          a message ID.
        - Otherwise return the stored mapping. If there is none, create the
          thread with ``recipient.create_thread`` and store the mapping.

        Creation happens under ``session.thread_creation_lock``, and the store
        is re-checked once the lock is held. Two handlers racing on the same
        new XMPP thread therefore do not both create a legacy thread.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)

        store = session.xmpp.store

        def lookup_stored_thread():
            # Raw stored legacy thread string for xmpp_thread, or None.
            with store.session() as orm:
                return store.id_map.get_thread(
                    orm, recipient.stored.id, xmpp_thread, recipient.is_group
                )

        # Fast path: mapping already known, no need to take the lock.
        legacy_thread_str = lookup_stored_thread()
        if legacy_thread_str is not None:
            return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)

        async with session.thread_creation_lock:
            # Another handler may have created and stored this thread while we
            # were waiting for the lock; reuse it instead of creating a duplicate.
            legacy_thread_str = lookup_stored_thread()
            if legacy_thread_str is not None:
                return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)
            legacy_thread = await recipient.create_thread(xmpp_thread)
            with store.session() as orm:
                store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return legacy_thread
def _ignore(session: "BaseSession", msg: Message) -> bool:
    """Tell whether *msg* should be dropped by the dispatcher.

    Messages whose ID starts with ``slidge-carbon-`` are always dropped.
    Messages whose ID is in ``session.ignore_messages`` are dropped once:
    the ID is removed from that collection as it is matched.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False
async def get_recipient(session: "BaseSession", m: Message) -> RecipientType:
    """Return the contact or MUC that message *m* is addressed to.

    :raises XMPPError: ``not-acceptable`` for a groupchat message whose sender
        resource is not among the MUC's user resources; that resource is also
        scheduled to be kicked from the MUC.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc
    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")
# Type variable for the slixmpp stanza type a handler receives.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
# Signature of the async stanza handlers wrapped by exceptions_to_xmpp_errors:
# (Any first argument, presumably the dispatcher instance, stanza) -> None.
HandlerType = Callable[[Any, StanzaType], Awaitable[None]]
def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """Wrap an async stanza handler so that its failures become XMPP errors.

    - ``Ignore`` is swallowed (the stanza is silently dropped).
    - ``XMPPError`` propagates unchanged.
    - ``NotImplementedError`` is logged at debug level and converted to
      ``feature-not-implemented``.
    - Any other ``Exception`` is logged at error level and converted to
      ``internal-server-error`` carrying the exception's text.

    Exceptions that do not derive from ``Exception`` (other than ``Ignore``)
    propagate unchanged. Converted errors are chained to the original
    exception, so tracebacks still show where the failure came from.
    """

    @wraps(cb)
    async def wrapped(*args):
        try:
            await cb(*args)
        except Ignore:
            pass
        except XMPPError:
            raise
        except NotImplementedError as e:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented", "Not implemented by the legacy module"
            ) from e
        except Exception as e:
            log.error("Failed to handle incoming stanza: %s", args, exc_info=e)
            raise XMPPError("internal-server-error", str(e)) from e

    return wrapped
# Module-level logger. It is defined at the bottom of the module, but the
# functions above only look it up when they are called, by which time the
# module has finished executing and the name is bound.
log = logging.getLogger(__name__)