Coverage for slidge / core / dispatcher / util.py: 91%

117 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

1import logging 

2from collections.abc import Awaitable, Callable 

3from functools import wraps 

4from typing import TYPE_CHECKING, Any, TypeVar 

5 

6from slixmpp import JID, Iq, Message, Presence 

7from slixmpp.exceptions import XMPPError 

8from slixmpp.xmlstream import StanzaBase 

9 

10from ...contact.roster import ContactIsUser 

11from ...util.types import LegacyMessageType, Recipient, RecipientType 

12from ..session import BaseSession 

13 

14if TYPE_CHECKING: 

15 from slidge import BaseGateway 

16 from slidge.group import LegacyMUC 

17 

18 

class Ignore(BaseException):
    """Raised to signal that an incoming stanza must be silently dropped.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not intercept it.
    """

21 

22 

class DispatcherMixin:
    """Helpers shared by stanza dispatchers.

    Covers session lookup for incoming stanzas, MUC resolution, and the
    conversion of XMPP message/thread IDs into their legacy counterparts.
    """

    # The mixin declares no slots of its own; storage for ``xmpp`` has to be
    # provided by the concrete class (a slot or a ``__dict__``).
    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session of the stanza's sender.

        :param stanza: The incoming stanza.
        :param timeout: Forwarded to ``session.wait_for_ready()``.
        :param wait_for_ready: Whether to await ``session.wait_for_ready()``.
        :param logged: Whether to call ``session.raise_if_not_logged()``.
        :raises Ignore: If the sender's domain equals the component's bare
            JID (echo), if a "chat" message is addressed to the component
            itself, or if ``_ignore()`` flags the message.
        :raises XMPPError: "registration-required" if no session matches the
            sender.
        """
        xmpp = self.xmpp
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        # Only messages go through the carbon/ignore-list filter.
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session associated with ``jid``.

        :param jid: JID passed to ``xmpp.get_session_from_jid()``.
        :param timeout: Forwarded to ``session.wait_for_ready()``.
        :param wait_for_ready: Whether to await ``session.wait_for_ready()``.
        :param logged: Whether to call ``session.raise_if_not_logged()``
            (presumably raising when the user is not logged in -- confirm
            against ``BaseSession``).
        :raises XMPPError: "registration-required" if the gateway knows no
            session for ``jid``.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
        """Return the MUC a stanza is addressed to.

        :param iq: The incoming stanza; its "to" JID is looked up in the
            sender's bookmarks.
        :raises XMPPError: "bad-request" if the stanza targets the component
            itself rather than a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        return await session.bookmarks.by_jid(ito)

    def _xmpp_msg_id_to_legacy(
        self,
        session: "BaseSession[LegacyMessageType, Any]",
        xmpp_id: str,
        recipient: Recipient,
        origin: bool = False,
    ) -> LegacyMessageType:
        """Convert an XMPP message ID into a legacy message ID.

        The ID map of the store is consulted first; if it holds no entry,
        ``session.xmpp_to_legacy_msg_id()`` is used as a fallback.

        :param session: Session providing the fallback conversion.
        :param xmpp_id: The XMPP message ID to convert.
        :param recipient: Chat the message belongs to.
        :param origin: Forwarded to ``id_map.get_legacy()``.
        :raises XMPPError: Propagated unchanged from the fallback, or
            "internal-server-error" if the fallback raises anything else.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
        if sent is not None:
            return self.xmpp.LEGACY_MSG_ID_TYPE(sent)

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            # Chain explicitly so the original failure is kept as the cause.
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple["BaseSession", Recipient, int | str]:
        """Resolve session, recipient and legacy thread for a message.

        :param msg: The incoming message.
        :return: ``(session, recipient, legacy_thread)``.
        """
        # NOTE(review): the third element is None when the message carries no
        # thread (see _xmpp_to_legacy_thread); the annotation does not
        # reflect that.
        session = await self._get_session(msg)
        recipient: Recipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, recipient)
        return session, recipient, legacy_thread

    async def _xmpp_to_legacy_thread(
        self, session: "BaseSession", msg: Message, recipient: RecipientType
    ):
        """Map the XMPP thread of ``msg`` to a legacy thread ID.

        :param session: Session of the sender.
        :param msg: Message whose ``thread`` element is inspected.
        :param recipient: Chat the message is addressed to.
        :return: ``None`` if the message has no thread; otherwise the legacy
            thread ID, created through ``recipient.create_thread()`` and
            persisted in the ID map when no mapping exists yet.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)

        store = session.xmpp.store

        def stored_thread():
            # Look up an already persisted mapping for this XMPP thread.
            with store.session() as orm:
                return store.id_map.get_thread(
                    orm, recipient.stored.id, xmpp_thread, recipient.is_group
                )

        # Fast path: the mapping already exists, no lock needed.
        legacy_thread_str = stored_thread()
        if legacy_thread_str is not None:
            return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)

        async with session.thread_creation_lock:
            # Re-check under the lock: another handler may have created and
            # stored this thread while we were waiting, and creating it a
            # second time would yield a duplicate legacy thread.
            legacy_thread_str = stored_thread()
            if legacy_thread_str is not None:
                return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)

            legacy_thread = await recipient.create_thread(xmpp_thread)
            with store.session() as orm:
                store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return legacy_thread

139 

140 

def _ignore(session: "BaseSession", msg: Message) -> bool:
    """Tell whether ``msg`` should be dropped instead of being handled.

    A message is dropped when its ID starts with ``slidge-carbon-`` or when
    its ID is listed in ``session.ignore_messages``; in the latter case the
    ID is removed from that collection.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        # Each listed ID is consumed by the first message matching it.
        session.ignore_messages.remove(msg_id)
        return True
    return False

150 

151 

async def get_recipient(session: "BaseSession", m: Message) -> RecipientType:
    """Resolve the chat (contact or MUC) that ``m`` is addressed to.

    :param session: Session of the sender; ``raise_if_not_logged()`` is
        called on it first.
    :param m: The incoming message.
    :return: The MUC for "groupchat" messages, the contact otherwise.
    :raises XMPPError: "not-acceptable" if a groupchat message comes from a
        resource that is not among the MUC's user resources; a task kicking
        that resource is scheduled on the session before raising.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc

    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

163 

164 

# Any stanza class deriving from slixmpp's StanzaBase.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
# Async handler taking two positional arguments -- the first typed ``Any``
# (presumably the dispatcher instance; confirm against callers), the second
# a stanza -- and resolving to None.
HandlerType = Callable[[Any, StanzaType], Awaitable[None]]

167 

168 

def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """Decorate an async stanza handler so its failures become XMPP errors.

    Behaviour of the wrapped handler:

    - ``Ignore`` is swallowed silently;
    - ``XMPPError`` is propagated untouched;
    - ``NotImplementedError`` becomes "feature-not-implemented";
    - ``ContactIsUser`` becomes "bad-request";
    - any other ``Exception`` is logged and becomes
      "internal-server-error" carrying ``str()`` of the exception.

    Translated errors are chained to the exception that caused them.

    :param cb: The coroutine function to wrap.
    :return: The wrapping coroutine function; the handler's return value is
        discarded.
    """

    @wraps(cb)
    async def wrapped(*args, **kwargs):
        try:
            # Forward keyword arguments too, so the wrapper stays transparent.
            await cb(*args, **kwargs)
        except Ignore:
            pass
        except XMPPError:
            raise
        except NotImplementedError as e:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented",
                f"{cb.__name__} is not implemented by the legacy module",
                clear=False,
            ) from e
        except ContactIsUser as e:
            raise XMPPError(
                "bad-request", "Actions with your bridged self are not allowed."
            ) from e
        except Exception as e:
            log.error("Failed to handle incoming stanza: %s", args, exc_info=e)
            raise XMPPError("internal-server-error", str(e)) from e

    return wrapped

194 

195 

# Module-level logger. It is bound at the end of the module, which works
# because the functions above only look it up when they are called.
log = logging.getLogger(__name__)