Coverage for slidge / core / dispatcher / util.py: 91%
119 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1import logging
2from collections.abc import Awaitable, Callable
3from functools import wraps
4from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
6from slixmpp import JID, Iq, Message, Presence
7from slixmpp.exceptions import XMPPError
8from slixmpp.xmlstream import StanzaBase
10from ...contact.roster import ContactIsUser
11from ...util.types import (
12 AnyMUC,
13 AnyRecipient,
14 AnySession,
15 LegacyMessageType,
16 RecipientType,
17)
18from ..session import BaseSession
20if TYPE_CHECKING:
21 from slidge import BaseGateway
class Ignore(BaseException):
    """Signal that the stanza being processed must be silently dropped.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not intercept it.
    """
class DispatcherMixin:
    """Helpers shared by the classes that dispatch incoming stanzas.

    The methods resolve the session, the recipient and the legacy identifiers
    (message IDs, thread IDs) that correspond to a stanza received by the
    gateway stored in ``self.xmpp``.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """Return the session of the entity that sent ``stanza``.

        :param stanza: The incoming stanza.
        :param timeout: Forwarded to ``session.wait_for_ready()``.
        :param wait_for_ready: Whether to await ``session.wait_for_ready()``.
        :param logged: Whether to call ``session.raise_if_not_logged()``.
        :raises Ignore: If the sender's server equals the gateway's bare JID
            (echo), if a "chat" message is addressed to the gateway's bare
            JID, or if ``_ignore()`` says the message must be skipped.
        :raises XMPPError: "registration-required" if the sender has no
            session.
        """
        xmpp = self.xmpp
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """Return the session associated with ``jid``.

        :param jid: JID passed to ``self.xmpp.get_session_from_jid()``.
        :param timeout: Forwarded to ``session.wait_for_ready()``.
        :param wait_for_ready: Whether to await ``session.wait_for_ready()``.
        :param logged: Whether to call ``session.raise_if_not_logged()``.
        :raises XMPPError: "registration-required" if no session is found.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> AnyMUC:
        """Return the MUC that ``iq`` is addressed to.

        :param iq: The incoming stanza; its "to" JID designates the MUC.
        :raises XMPPError: "bad-request" if the stanza is addressed to the
            gateway's own bare JID.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc  # type:ignore[no-any-return]

    def _xmpp_msg_id_to_legacy(
        self,
        session: "BaseSession[LegacyMessageType, Any]",
        xmpp_id: str,
        recipient: AnyRecipient,
        origin: bool = False,
    ) -> LegacyMessageType:
        """Convert an XMPP message ID into a legacy message ID.

        The ID map of the store is consulted first; if it has no entry, the
        conversion is delegated to ``session.xmpp_to_legacy_msg_id()``.

        :param session: Session used for the fallback conversion.
        :param xmpp_id: The XMPP message ID to convert.
        :param recipient: Entity whose stored ID and group flag scope the
            lookup.
        :param origin: Forwarded to ``id_map.get_legacy()``.
        :raises XMPPError: Re-raised as is when the fallback raises one;
            "internal-server-error" when the fallback raises anything else.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
        if sent is not None:
            return self.xmpp.LEGACY_MSG_ID_TYPE(sent)  # type:ignore[no-any-return]

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            # Chain explicitly so the original failure is kept as the cause.
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple[AnySession, AnyRecipient, int | str | None]:
        """Resolve the session, recipient and legacy thread of ``msg``.

        :param msg: The incoming message.
        :return: ``(session, recipient, legacy_thread)``; the thread is
            ``None`` when the message carries no thread.
        """
        session = await self._get_session(msg)
        e: AnyRecipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return session, e, legacy_thread

    async def _xmpp_to_legacy_thread(
        self, session: AnySession, msg: Message, recipient: RecipientType
    ) -> str | int | None:
        """Return the legacy thread matching the XMPP thread of ``msg``.

        A known mapping is read from the store. Otherwise a thread is created
        through ``recipient.create_thread()`` and the new mapping is stored.

        :param session: Session of the sender.
        :param msg: The incoming message; ``msg["thread"]`` is the XMPP thread.
        :param recipient: Entity the message is addressed to.
        :return: The legacy thread, or ``None`` if the message has no thread.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)  # type:ignore[no-any-return]

        store = session.xmpp.store

        def stored_thread() -> Any:
            # Read the legacy thread currently mapped to this XMPP thread.
            with store.session() as orm:
                return store.id_map.get_thread(
                    orm, recipient.stored.id, xmpp_thread, recipient.is_group
                )

        # Fast path: the mapping already exists, no need to take the lock.
        legacy_thread_str = stored_thread()
        if legacy_thread_str is not None:
            return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)  # type:ignore[no-any-return]

        async with session.thread_creation_lock:
            # Another task may have created and stored this thread while we
            # were waiting for the lock: look again before creating a
            # duplicate on the legacy side.
            legacy_thread_str = stored_thread()
            if legacy_thread_str is not None:
                return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)  # type:ignore[no-any-return]

            legacy_thread = await recipient.create_thread(xmpp_thread)
            with store.session() as orm:
                store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return legacy_thread
def _ignore(session: AnySession, msg: Message) -> bool:
    """Tell whether ``msg`` must be skipped instead of being processed.

    A message is skipped when its ID starts with ``"slidge-carbon-"`` or when
    its ID is listed in ``session.ignore_messages``. In the latter case the ID
    is removed from that collection, so it is only skipped once.

    :param session: Session of the sender.
    :param msg: The incoming message.
    :return: ``True`` if the message must be ignored.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False
async def get_recipient(session: AnySession, m: Message) -> AnyRecipient:
    """Return the entity that the message ``m`` is addressed to.

    For a "groupchat" message this is the MUC found in ``session.bookmarks``;
    for any other message type it is the contact found in
    ``session.contacts``. Both are looked up by the "to" JID of the message.

    :param session: Session of the sender; ``raise_if_not_logged()`` is
        called on it first.
    :param m: The incoming message.
    :raises XMPPError: "not-acceptable" if the sender's resource is not one
        of the MUC's user resources. A task kicking that resource is
        scheduled through ``session.create_task()`` before raising.
    """
    session.raise_if_not_logged()

    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())  # type:ignore[no-any-return]

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc  # type:ignore[no-any-return]

    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")
# Parameter specification variable; not referenced in the visible part of
# this module.
P = ParamSpec("P")

# Type of the object a stanza handler receives as its first argument.
SelfType = TypeVar("SelfType")
# Any stanza class; bounded by slixmpp's ``StanzaBase``.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)

# Signature of a stanza handler: an async callable taking the handler's owner
# and the stanza, and returning nothing.
HandlerType = Callable[[SelfType, StanzaType], Awaitable[None]]
def exceptions_to_xmpp_errors(cb: HandlerType[Any, Any]) -> HandlerType[Any, Any]:
    """Decorate a stanza handler so that its failures become XMPP errors.

    The wrapped handler awaits ``cb`` and translates what it raises:

    - ``Ignore`` is swallowed: the stanza is dropped silently;
    - ``XMPPError`` is re-raised untouched;
    - ``NotImplementedError`` becomes a "feature-not-implemented" error;
    - ``ContactIsUser`` becomes a "bad-request" error;
    - any other ``Exception`` is logged and becomes an
      "internal-server-error" error whose text is ``str()`` of the exception.

    Translated errors are raised ``from`` the original exception, so the
    cause is kept explicitly on the resulting ``XMPPError``.

    :param cb: The async handler to wrap, called as ``cb(self, stanza)``.
    :return: The wrapping coroutine function, carrying ``cb``'s metadata.
    """

    @wraps(cb)
    async def wrapped(self: SelfType, stanza: StanzaType) -> None:
        try:
            await cb(self, stanza)
        except Ignore:
            # Deliberate "drop this stanza" signal: nothing to report.
            pass
        except XMPPError:
            raise
        except NotImplementedError as e:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented",
                f"{cb.__name__} is not implemented by the legacy module",
                clear=False,
            ) from e
        except ContactIsUser as e:
            raise XMPPError(
                "bad-request", "Actions with your bridged self are not allowed."
            ) from e
        except Exception as e:
            log.error(
                "Failed to handle incoming stanza: %s - %s", self, stanza, exc_info=e
            )
            raise XMPPError("internal-server-error", str(e)) from e

    return wrapped
# Module-level logger. It is bound at the end of the module, which works
# because the functions above only look the name up when they are called.
log = logging.getLogger(__name__)