Coverage for slidge/core/dispatcher/util.py: 91%

113 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

1import logging 

2from functools import wraps 

3from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar 

4 

5from slixmpp import JID, Iq, Message, Presence 

6from slixmpp.exceptions import XMPPError 

7from slixmpp.xmlstream import StanzaBase 

8 

9from ...util.types import LegacyMessageType, Recipient, RecipientType 

10from ..session import BaseSession 

11 

12if TYPE_CHECKING: 

13 from slidge import BaseGateway 

14 from slidge.group import LegacyMUC 

15 

16 

class Ignore(BaseException):
    """Control-flow signal meaning "drop the current stanza, send no reply".

    Because it subclasses BaseException rather than Exception, generic
    ``except Exception`` clauses do not intercept it; it has to be caught by
    name.
    """

19 

20 

class DispatcherMixin:
    """Helpers shared by the stanza dispatchers.

    Resolves the user session behind an incoming stanza and translates XMPP
    message/thread IDs into their legacy-network counterparts.
    """

    # NOTE(review): with an empty __slots__ and no other base, an instance of
    # this class alone has no __dict__, so ``self.xmpp = ...`` relies on a
    # subclass that provides one (hence the ``type:ignore`` below).
    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session of the user who sent *stanza*.

        :param stanza: incoming stanza; its ``from`` identifies the user.
        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: await the session being ready before returning.
        :param logged: call ``session.raise_if_not_logged()`` first.
        :raises Ignore: when the sender's domain is the component's own JID
            (logged as an echo), for ``chat`` messages addressed to the
            component's bare JID, and for messages flagged by ``_ignore``.
        :raises XMPPError: ``registration-required`` if the sender has no
            session.
        """
        xmpp = self.xmpp
        sender = stanza.get_from()
        if sender.server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            sender, timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> BaseSession:
        """Return the session registered for *jid*.

        :param timeout: forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: await the session being ready before returning.
        :param logged: call ``session.raise_if_not_logged()`` first.
        :raises XMPPError: ``registration-required`` if there is no session.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
        """Return the MUC a stanza is addressed to.

        The sender must have a logged-in session.

        :raises XMPPError: ``bad-request`` if the stanza targets the component's
            bare JID rather than a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc

    def _xmpp_msg_id_to_legacy(
        self,
        session: "BaseSession[LegacyMessageType, Any]",
        xmpp_id: str,
        recipient: Recipient,
        origin: bool = False,
    ) -> LegacyMessageType:
        """Translate an XMPP message ID into a legacy message ID.

        The stored ID map is consulted first; if it has no entry, fall back to
        ``session.xmpp_to_legacy_msg_id``.

        :param origin: forwarded as-is to the ID map lookup.
        :raises XMPPError: re-raised from the fallback, or
            ``internal-server-error`` if the fallback fails in any other way.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
        if sent is not None:
            return self.xmpp.LEGACY_MSG_ID_TYPE(sent)

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            # Chain explicitly so the original failure is kept as __cause__.
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_recipient_thread(
        self, msg: Message
    ) -> tuple["BaseSession", Recipient, int | str | None]:
        """Return ``(session, recipient, legacy_thread)`` for an incoming message.

        ``legacy_thread`` is None when the message carries no thread.
        """
        session = await self._get_session(msg)
        e: Recipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return session, e, legacy_thread

    @staticmethod
    def _get_stored_legacy_thread(
        session: "BaseSession", xmpp_thread: str, recipient: RecipientType
    ):
        """Return the legacy thread ID already mapped to *xmpp_thread*, or None."""
        with session.xmpp.store.session() as orm:
            legacy_thread_str = session.xmpp.store.id_map.get_thread(
                orm, recipient.stored.id, xmpp_thread, recipient.is_group
            )
        if legacy_thread_str is None:
            return None
        return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)

    async def _xmpp_to_legacy_thread(
        self, session: "BaseSession", msg: Message, recipient: RecipientType
    ):
        """Return the legacy thread ID for *msg*'s XMPP thread, creating it if needed.

        Returns None when *msg* has no thread. If the session declares that
        message IDs are thread IDs, the thread is translated like a message ID.
        Otherwise the stored mapping is used; when there is none, a thread is
        created through ``recipient.create_thread`` and the mapping is stored.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread, recipient)

        # Fast path: the thread was created earlier, no need to take the lock.
        legacy_thread = self._get_stored_legacy_thread(
            session, xmpp_thread, recipient
        )
        if legacy_thread is not None:
            return legacy_thread

        async with session.thread_creation_lock:
            # Re-check under the lock: another task may have created and stored
            # this same thread while we were waiting to acquire it. Without
            # this, concurrent messages would each create a duplicate thread.
            legacy_thread = self._get_stored_legacy_thread(
                session, xmpp_thread, recipient
            )
            if legacy_thread is not None:
                return legacy_thread
            legacy_thread = await recipient.create_thread(xmpp_thread)
            with session.xmpp.store.session() as orm:
                session.xmpp.store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        return legacy_thread

137 

138 

def _ignore(session: "BaseSession", msg: Message) -> bool:
    """Tell whether *msg* is one of our own echoes that must be dropped.

    Returns True when the stanza ID starts with ``slidge-carbon-``. Otherwise,
    if the ID is listed in ``session.ignore_messages``, it is logged, removed
    from that collection, and True is returned. All other IDs yield False.
    """
    stanza_id = msg.get_id()
    if stanza_id.startswith("slidge-carbon-"):
        return True
    if stanza_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", stanza_id)
        session.ignore_messages.remove(stanza_id)
        return True
    return False

148 

149 

async def get_recipient(session: "BaseSession", m: Message) -> RecipientType:
    """Resolve the legacy entity (MUC or contact) that *m* is addressed to.

    The session must be logged in. For ``groupchat`` messages, the MUC is
    looked up by the ``to`` JID. If the sender's resource is not one of the
    MUC's user resources, a kick of that resource is scheduled and
    ``XMPPError("not-acceptable")`` is raised. Any other message type resolves
    to the contact at the ``to`` JID.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc
    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

161 

162 

# Type variable for a stanza, restricted to subclasses of slixmpp's StanzaBase.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
# Shape of a dispatcher handler: an async callable taking an untyped first
# argument (presumably the dispatcher instance — confirm) and a stanza, and
# returning None.
HandlerType = Callable[[Any, StanzaType], Awaitable[None]]

165 

166 

def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """Decorate a stanza handler so that its failures become XMPP errors.

    - ``Ignore`` is swallowed, so the wrapper returns None quietly.
    - ``XMPPError`` is re-raised unchanged.
    - ``NotImplementedError`` is turned into ``feature-not-implemented``.
    - Any other ``Exception`` is logged with the handler's arguments and
      turned into ``internal-server-error`` carrying the exception text.
    """

    @wraps(cb)
    async def wrapped(*handler_args):
        try:
            await cb(*handler_args)
        except Ignore:
            # Deliberate drop: nothing to report.
            return
        except XMPPError:
            # Already an XMPP-level error; keep it as is.
            raise
        except NotImplementedError:
            log.debug("NotImplementedError raised in %s", cb)
            reason = f"{cb.__name__} is not implemented by the legacy module"
            raise XMPPError("feature-not-implemented", reason, clear=False)
        except Exception as exc:
            log.error(
                "Failed to handle incoming stanza: %s", handler_args, exc_info=exc
            )
            raise XMPPError("internal-server-error", str(exc))

    return wrapped

188 

189 

# Module-wide logger. It is defined last, but the functions above only look it
# up when they are called, so the late definition is harmless.
log = logging.getLogger(name=__name__)