Coverage for slidge / core / dispatcher / util.py: 91%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import logging 

2from collections.abc import Awaitable, Callable 

3from functools import wraps 

4from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar 

5 

6from slixmpp import JID, Iq, Message, Presence 

7from slixmpp.exceptions import XMPPError 

8from slixmpp.xmlstream import StanzaBase 

9 

10from ...contact.roster import ContactIsUser 

11from ...util.types import ( 

12 AnyMUC, 

13 AnyRecipient, 

14 AnySession, 

15 LegacyMessageType, 

16 RecipientType, 

17) 

18from ..session import BaseSession 

19 

20if TYPE_CHECKING: 

21 from slidge import BaseGateway 

22 

23 

class Ignore(BaseException):
    """Signal that the current stanza must be dropped without any reply.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers (such as the catch-all branch in
    ``exceptions_to_xmpp_errors``) do not turn it into an error reply.
    """

26 

27 

class DispatcherMixin:
    """Helpers shared by the stanza dispatchers.

    They map an incoming stanza to the user's session, the legacy recipient
    (contact or MUC) and the legacy message/thread identifiers.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        # NOTE(review): with an empty __slots__ here, this assignment relies on
        # a subclass providing an instance __dict__ or an ``xmpp`` slot.
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """
        Return the session of the user who sent ``stanza``.

        :param stanza: incoming stanza; its ``from`` identifies the user.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: whether to await ``session.wait_for_ready``.
        :param logged: if True, call ``session.raise_if_not_logged()``.
        :raises Ignore: if the sender's server equals the component's bare JID
            (an echo), if the stanza is a ``chat`` message addressed to the
            component's bare JID, or if ``_ignore`` flags the message.
        :raises XMPPError: ``registration-required`` when the sender has no
            session.
        """
        xmpp = self.xmpp
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """
        Return the session that ``self.xmpp.get_session_from_jid`` finds for ``jid``.

        :param jid: JID of the user.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: whether to await ``session.wait_for_ready``.
        :param logged: if True, call ``session.raise_if_not_logged()``.
        :raises XMPPError: ``registration-required`` when no session exists.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> AnyMUC:
        """
        Return the MUC that ``iq`` is addressed to.

        The sender must have a logged-in session.

        :raises XMPPError: ``bad-request`` when the stanza is addressed to the
            component's bare JID rather than to a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc  # type:ignore[no-any-return]

    def _xmpp_msg_id_to_legacy(
        self,
        session: "BaseSession[LegacyMessageType, Any]",
        xmpp_id: str,
        recipient: AnyRecipient,
        origin: bool = False,
    ) -> LegacyMessageType:
        """
        Convert an XMPP message ID into the legacy message ID.

        The ID map in the store is consulted first. If it has no entry, the
        session's ``xmpp_to_legacy_msg_id`` is used.

        :param session: session of the user.
        :param xmpp_id: XMPP-side message ID.
        :param recipient: contact or MUC the message belongs to.
        :param origin: forwarded to ``id_map.get_legacy``.
        :raises XMPPError: propagated unchanged from ``xmpp_to_legacy_msg_id``.
            Any other exception from it is wrapped as ``internal-server-error``.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
            if sent is not None:
                return self.xmpp.LEGACY_MSG_ID_TYPE(sent)  # type:ignore[no-any-return]

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple[AnySession, AnyRecipient, int | str | None]:
        """Return ``(session, recipient, legacy_thread)`` for an incoming message."""
        session = await self._get_session(msg)
        e: AnyRecipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return session, e, legacy_thread

    @staticmethod
    def _get_stored_legacy_thread(
        session: AnySession, xmpp_thread: str, recipient: RecipientType
    ) -> str | int | None:
        """Return the legacy thread already mapped to ``xmpp_thread``, or None."""
        with session.xmpp.store.session() as orm:
            legacy_thread_str = session.xmpp.store.id_map.get_thread(
                orm, recipient.stored.id, xmpp_thread, recipient.is_group
            )
            if legacy_thread_str is not None:
                return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)  # type:ignore[no-any-return]
        return None

    async def _xmpp_to_legacy_thread(
        self, session: AnySession, msg: Message, recipient: RecipientType
    ) -> str | int | None:
        """
        Return the legacy thread ID for the ``<thread>`` of ``msg``.

        If the message has no thread, return None. If the session declares
        ``MESSAGE_IDS_ARE_THREAD_IDS``, the thread is converted like a message
        ID. Otherwise the stored mapping is used, and when no mapping exists a
        thread is created with ``recipient.create_thread`` and stored.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)  # type:ignore[no-any-return]

        # Fast path: the thread is already mapped, so the lock is not needed.
        stored = self._get_stored_legacy_thread(session, xmpp_thread, recipient)
        if stored is not None:
            return stored

        async with session.thread_creation_lock:
            # Check again under the lock. Another coroutine may have created and
            # stored this thread while we waited; without this check, each
            # waiter would create its own duplicate legacy thread.
            stored = self._get_stored_legacy_thread(session, xmpp_thread, recipient)
            if stored is not None:
                return stored
            created = await recipient.create_thread(xmpp_thread)
            with session.xmpp.store.session() as orm:
                session.xmpp.store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(created),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return created  # type:ignore[no-any-return]

144 

145 

def _ignore(session: AnySession, msg: Message) -> bool:
    """
    Decide whether an incoming message must be dropped.

    A message is dropped when its ID starts with ``slidge-carbon-``, or when
    its ID is in ``session.ignore_messages``. A listed ID is consumed: it is
    removed from that collection, so it is ignored only once.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False

155 

156 

async def get_recipient(session: AnySession, m: Message) -> AnyRecipient:
    """
    Return the contact or MUC that message ``m`` is addressed to.

    For ``groupchat`` messages, the sending resource must be one of the
    MUC's user resources. Otherwise a task is scheduled to kick that resource
    and an error is raised.

    :raises XMPPError: ``not-acceptable`` when the sending resource is not
        connected to the MUC.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())  # type:ignore[no-any-return]

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc  # type:ignore[no-any-return]
    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

168 

169 

# Generic parameters for the handler callables used in this module.
P = ParamSpec("P")  # NOTE(review): not referenced in this file; may be used by importers
SelfType = TypeVar("SelfType")
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
# An async handler method taking (self, stanza) and returning an awaitable of None.
HandlerType = Callable[[SelfType, StanzaType], Awaitable[None]]

174 

175 

def exceptions_to_xmpp_errors(cb: HandlerType[Any, Any]) -> HandlerType[Any, Any]:
    """
    Decorate an async stanza handler so that its failures become XMPP errors.

    - ``Ignore`` is swallowed, and the handler simply returns.
    - ``XMPPError`` is re-raised unchanged.
    - ``NotImplementedError`` becomes ``feature-not-implemented``.
    - ``ContactIsUser`` becomes ``bad-request``.
    - Any other ``Exception`` is logged and becomes ``internal-server-error``,
      with the exception's text as the error text.
    """

    @wraps(cb)
    async def guarded(self: SelfType, stanza: StanzaType) -> None:
        try:
            await cb(self, stanza)
        except Ignore:
            return
        except XMPPError:
            raise
        except NotImplementedError:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented",
                f"{cb.__name__} is not implemented by the legacy module",
                clear=False,
            )
        except ContactIsUser:
            raise XMPPError(
                "bad-request", "Actions with your bridged self are not allowed."
            )
        except Exception as exc:
            log.error(
                "Failed to handle incoming stanza: %s - %s", self, stanza, exc_info=exc
            )
            raise XMPPError("internal-server-error", str(exc))

    return guarded

203 

204 

# Module logger; functions above resolve it at call time.
log = logging.getLogger(name=__name__)