Coverage for slidge/core/dispatcher/util.py: 89%
111 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import logging
2from functools import wraps
3from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar
5from slixmpp import JID, Iq, Message, Presence
6from slixmpp.exceptions import XMPPError
7from slixmpp.xmlstream import StanzaBase
9from ...util.types import Recipient, RecipientType
10from ..session import BaseSession
12if TYPE_CHECKING:
13 from slidge import BaseGateway
14 from slidge.group import LegacyMUC
class Ignore(BaseException):
    """Signal that an incoming stanza should be dropped without a reply.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` clauses do not intercept it.
    """
21class DispatcherMixin:
22 def __init__(self, xmpp: "BaseGateway"):
23 self.xmpp = xmpp
25 async def _get_session(
26 self,
27 stanza: Message | Presence | Iq,
28 timeout: int | None = 10,
29 wait_for_ready=True,
30 logged=False,
31 ) -> BaseSession:
32 xmpp = self.xmpp
33 if stanza.get_from().server == xmpp.boundjid.bare:
34 log.debug("Ignoring echo")
35 raise Ignore
36 if (
37 isinstance(stanza, Message)
38 and stanza.get_type() == "chat"
39 and stanza.get_to() == xmpp.boundjid.bare
40 ):
41 log.debug("Ignoring message to component")
42 raise Ignore
43 session = await self._get_session_from_jid(
44 stanza.get_from(), timeout, wait_for_ready, logged
45 )
46 if isinstance(stanza, Message) and _ignore(session, stanza):
47 raise Ignore
48 return session
50 async def _get_session_from_jid(
51 self,
52 jid: JID,
53 timeout: int | None = 10,
54 wait_for_ready=True,
55 logged=False,
56 ) -> BaseSession:
57 session = self.xmpp.get_session_from_jid(jid)
58 if session is None:
59 raise XMPPError("registration-required")
60 if logged:
61 session.raise_if_not_logged()
62 if wait_for_ready:
63 await session.wait_for_ready(timeout)
64 return session
66 async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
67 ito = iq.get_to()
68 if ito == self.xmpp.boundjid.bare:
69 raise XMPPError("bad-request", text="This is only handled for MUCs")
71 session = await self._get_session(iq, logged=True)
72 muc = await session.bookmarks.by_jid(ito)
73 return muc
75 def _xmpp_msg_id_to_legacy(self, session: "BaseSession", xmpp_id: str):
76 sent = self.xmpp.store.sent.get_legacy_id(session.user_pk, xmpp_id)
77 if sent is not None:
78 return self.xmpp.LEGACY_MSG_ID_TYPE(sent)
80 multi = self.xmpp.store.multi.get_legacy_id(session.user_pk, xmpp_id)
81 if multi:
82 return self.xmpp.LEGACY_MSG_ID_TYPE(multi)
84 try:
85 return session.xmpp_to_legacy_msg_id(xmpp_id)
86 except XMPPError:
87 raise
88 except Exception as e:
89 log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
90 raise XMPPError(
91 "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
92 )
94 async def _get_session_entity_thread(
95 self, msg: Message
96 ) -> tuple["BaseSession", Recipient, int | str]:
97 session = await self._get_session(msg)
98 e: Recipient = await _get_entity(session, msg)
99 legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
100 return session, e, legacy_thread
102 async def _xmpp_to_legacy_thread(
103 self, session: "BaseSession", msg: Message, recipient: RecipientType
104 ):
105 xmpp_thread = msg["thread"]
106 if not xmpp_thread:
107 return None
109 if session.MESSAGE_IDS_ARE_THREAD_IDS:
110 return self._xmpp_msg_id_to_legacy(session, xmpp_thread)
112 legacy_thread_str = session.xmpp.store.sent.get_legacy_thread(
113 session.user_pk, xmpp_thread
114 )
115 if legacy_thread_str is not None:
116 return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)
117 async with session.thread_creation_lock:
118 legacy_thread = await recipient.create_thread(xmpp_thread)
119 session.xmpp.store.sent.set_thread(
120 session.user_pk, str(legacy_thread), xmpp_thread
121 )
122 return legacy_thread
def _ignore(session: "BaseSession", msg: Message):
    """Tell whether an incoming message should be dropped.

    A message is dropped when its ID starts with ``slidge-carbon-``, or when
    its ID is in ``session.ignore_messages``. In the latter case the ID is
    removed from that collection, so a later message with the same ID is not
    dropped unless the ID is added again.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False
async def _get_entity(session: "BaseSession", m: Message) -> RecipientType:
    """Return the entity that message ``m`` is addressed to.

    For ``groupchat`` messages this is the MUC from the session's bookmarks;
    otherwise it is the contact from the session's roster.

    :raises XMPPError: ``not-acceptable`` if the sending resource is not among
        the MUC's user resources; a task kicking that resource is scheduled
        first. Also whatever ``session.raise_if_not_logged()`` raises.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc
    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")
# Type variable for a stanza argument; restricted to StanzaBase subclasses.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
# Shape of an async stanza handler as accepted by exceptions_to_xmpp_errors:
# two positional arguments (the first typed Any - presumably the dispatcher
# instance, verify against callers - and the stanza), returning None.
HandlerType = Callable[[Any, StanzaType], Awaitable[None]]
def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """Decorate an async stanza handler so that failures become XMPP errors.

    - ``Ignore`` is swallowed and the wrapper returns ``None``.
    - ``XMPPError`` propagates unchanged.
    - ``NotImplementedError`` is logged at debug level and replaced by a
      ``feature-not-implemented`` error.
    - Any other ``Exception`` is logged with its traceback and replaced by an
      ``internal-server-error`` whose text is the exception message.

    Other ``BaseException`` subclasses are not intercepted. The wrapper only
    forwards positional arguments.
    """

    @wraps(cb)
    async def handler(*args):
        try:
            await cb(*args)
        except Ignore:
            return
        except XMPPError:
            raise
        except NotImplementedError:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented", "Not implemented by the legacy module"
            )
        except Exception as exc:
            log.error("Failed to handle incoming stanza: %s", args, exc_info=exc)
            raise XMPPError("internal-server-error", str(exc))

    return handler
# Module logger. It is defined at the end of the module, but the functions
# above only look it up when they run, so the position is harmless.
log = logging.getLogger(name=__name__)