Coverage for slidge / core / dispatcher / util.py: 91%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-06 15:18 +0000

1import logging 

2from functools import wraps 

3from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar 

4 

5from slixmpp import JID, Iq, Message, Presence 

6from slixmpp.exceptions import XMPPError 

7from slixmpp.xmlstream import StanzaBase 

8 

9from ...contact.roster import ContactIsUser 

10from ...util.types import LegacyMessageType, Recipient, RecipientType 

11from ..session import BaseSession 

12 

13if TYPE_CHECKING: 

14 from slidge import BaseGateway 

15 from slidge.group import LegacyMUC 

16 

17 

class Ignore(BaseException):
    """Signal that the stanza currently being handled should be dropped.

    ``exceptions_to_xmpp_errors`` swallows it without replying. Because it
    derives from ``BaseException`` rather than ``Exception``, generic
    ``except Exception`` clauses do not intercept it.
    """

20 

21 

class DispatcherMixin:
    """Helpers shared by stanza dispatchers.

    Provides session lookup from incoming stanzas, MUC lookup, and translation
    of XMPP message/thread identifiers into legacy identifiers.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        # __slots__ is empty on this mixin; storing ``xmpp`` presumably relies
        # on a subclass/other base providing an instance __dict__ — NOTE(review)
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session of the user who sent *stanza*.

        :param stanza: incoming stanza; its ``from`` JID selects the session.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: whether to await ``session.wait_for_ready``.
        :param logged: whether to call ``session.raise_if_not_logged``.
        :raises Ignore: if the sender's domain is the component's own bare JID
            (an echo), if it is a ``chat`` message addressed to the component's
            bare JID, or if ``_ignore`` flags the message.
        :raises XMPPError: ``registration-required`` when no session exists.
        """
        xmpp = self.xmpp
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session associated with *jid*.

        :param jid: JID of the user.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: whether to await ``session.wait_for_ready``.
        :param logged: whether to call ``session.raise_if_not_logged``.
        :raises XMPPError: ``registration-required`` when no session exists.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
        """Return the MUC that the stanza is addressed to.

        The sender's session must be logged in.

        :raises XMPPError: ``bad-request`` if the stanza is addressed to the
            component's bare JID instead of a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc

    def _xmpp_msg_id_to_legacy(
        self,
        session: "BaseSession[LegacyMessageType, Any]",
        xmpp_id: str,
        recipient: Recipient,
        origin: bool = False,
    ) -> LegacyMessageType:
        """Translate an XMPP message ID into a legacy message ID.

        First consults the stored ID map (``origin`` is forwarded to
        ``id_map.get_legacy``). If nothing is stored, it falls back to
        ``session.xmpp_to_legacy_msg_id``.

        :raises XMPPError: propagated from the session's converter as-is, or
            ``internal-server-error`` if the converter fails in any other way.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
            if sent is not None:
                return self.xmpp.LEGACY_MSG_ID_TYPE(sent)

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple["BaseSession", Recipient, int | str | None]:
        """Return ``(session, recipient, legacy_thread)`` for an incoming message.

        ``legacy_thread`` is None when the message carries no ``<thread>``.
        """
        session = await self._get_session(msg)
        recipient: Recipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, recipient)
        return session, recipient, legacy_thread

    @staticmethod
    def _stored_legacy_thread(
        session: "BaseSession", xmpp_thread: str, recipient: RecipientType
    ) -> Any:
        """Return the stored legacy thread mapped to *xmpp_thread*, or None."""
        with session.xmpp.store.session() as orm:
            legacy_thread_str = session.xmpp.store.id_map.get_thread(
                orm, recipient.stored.id, xmpp_thread, recipient.is_group
            )
            if legacy_thread_str is not None:
                # NOTE(review): converted with the *message* ID type, as before.
                return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)
        return None

    async def _xmpp_to_legacy_thread(
        self, session: "BaseSession", msg: Message, recipient: RecipientType
    ):
        """Return the legacy thread for the message's ``<thread>``, or None.

        If ``session.MESSAGE_IDS_ARE_THREAD_IDS`` is set, the thread is treated
        as a message ID. Otherwise a stored mapping is used. When no mapping
        exists, ``recipient.create_thread`` is called and the new mapping is
        stored.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)

        # Fast path: no locking needed when the thread is already known.
        legacy_thread = self._stored_legacy_thread(session, xmpp_thread, recipient)
        if legacy_thread is not None:
            return legacy_thread

        async with session.thread_creation_lock:
            # Re-check under the lock: another task may have created and stored
            # this thread while we waited; creating it again would duplicate it.
            legacy_thread = self._stored_legacy_thread(
                session, xmpp_thread, recipient
            )
            if legacy_thread is not None:
                return legacy_thread
            legacy_thread = await recipient.create_thread(xmpp_thread)
            with session.xmpp.store.session() as orm:
                session.xmpp.store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return legacy_thread

138 

139 

def _ignore(session: "BaseSession", msg: Message) -> bool:
    """Return True if *msg* must not be processed further.

    A message is ignored in two cases:

    - its ID starts with ``slidge-carbon-``;
    - its ID is in ``session.ignore_messages``. In this case the ID is also
      removed from that collection (so it is only ignored once), and a debug
      line is logged.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False

149 

150 

async def get_recipient(session: "BaseSession", m: Message) -> RecipientType:
    """Return the contact or MUC that message *m* is addressed to.

    The session must be logged in. For ``groupchat`` messages, the MUC is
    looked up from the bookmarks. If the sender's resource is not among the
    MUC's user resources, a kick of that resource is scheduled and an error
    is raised. For any other type, the contact is looked up by the ``to`` JID.

    :raises XMPPError: ``not-acceptable`` when the sending resource is not
        connected to the MUC.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    sender_resource = m.get_from().resource
    if sender_resource in muc.get_user_resources():
        return muc
    session.create_task(muc.kick_resource(sender_resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

162 

163 

# Generic stanza parameter: any slixmpp StanzaBase subclass.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)

# Shape of the async handlers wrapped by ``exceptions_to_xmpp_errors``:
# ``(first_arg, stanza) -> Awaitable[None]``. The first argument is typed
# ``Any`` (presumably the handler's owner, i.e. ``self`` — TODO confirm).
HandlerType = Callable[
    [Any, StanzaType],
    Awaitable[None],
]

166 

167 

def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """Decorate an async stanza handler so its failures become XMPP errors.

    - ``Ignore`` is swallowed silently.
    - ``XMPPError`` propagates unchanged.
    - ``NotImplementedError`` becomes ``feature-not-implemented``.
    - ``ContactIsUser`` becomes ``bad-request``.
    - Any other ``Exception`` is logged and becomes ``internal-server-error``
      with ``str(exc)`` as its text.

    The original exception is chained as ``__cause__``. Exceptions that are
    neither ``Exception`` subclasses nor ``Ignore`` (e.g.
    ``asyncio.CancelledError`` on Python 3.8+) propagate untouched. Positional
    and keyword arguments are both forwarded to *cb*.
    """

    @wraps(cb)
    async def wrapped(*args, **kwargs):
        try:
            await cb(*args, **kwargs)
        except Ignore:
            pass
        except XMPPError:
            raise
        except NotImplementedError as e:
            log.debug("NotImplementedError raised in %s", cb)
            # getattr: callables such as functools.partial have no __name__,
            # which would otherwise raise AttributeError here.
            handler_name = getattr(cb, "__name__", repr(cb))
            raise XMPPError(
                "feature-not-implemented",
                f"{handler_name} is not implemented by the legacy module",
                clear=False,
            ) from e
        except ContactIsUser as e:
            raise XMPPError(
                "bad-request", "Actions with your bridged self are not allowed."
            ) from e
        except Exception as e:
            log.error("Failed to handle incoming stanza: %s", args, exc_info=e)
            raise XMPPError("internal-server-error", str(e)) from e

    return wrapped

193 

194 

# Module logger; module-level functions and methods above resolve it at call time.
log = logging.getLogger(name=__name__)