Coverage for slidge / core / dispatcher / util.py: 95%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 04:38 +0000

1import logging 

2from collections.abc import Awaitable, Callable 

3from functools import wraps 

4from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar 

5 

6from slixmpp import JID, Iq, Message, Presence 

7from slixmpp.exceptions import XMPPError 

8from slixmpp.xmlstream import StanzaBase 

9 

10from ...contact.roster import ContactIsUser 

11from ...util.types import ( 

12 AnyMUC, 

13 AnyRecipient, 

14 AnySession, 

15 RecipientType, 

16) 

17 

18if TYPE_CHECKING: 

19 from slidge.util.types import AnyGateway 

20 

21 

class Ignore(BaseException):
    """Signal that the current stanza should be dropped silently.

    Raised by the session-lookup helpers and swallowed by
    ``exceptions_to_xmpp_errors``. It derives from ``BaseException`` rather
    than ``Exception``, so generic ``except Exception`` handlers do not
    intercept it.
    """

24 

25 

class DispatcherMixin:
    """Shared helpers for the stanza dispatchers.

    Provides session lookup for incoming stanzas, MUC/recipient resolution,
    and translation of XMPP message/thread IDs to their legacy counterparts.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "AnyGateway") -> None:
        self.xmpp = xmpp  # type:ignore[misc]

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """Return the session of the user who sent ``stanza``.

        :param stanza: The incoming stanza.
        :param timeout: Forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: If true, await ``session.wait_for_ready(timeout)``.
        :param logged: If true, call ``session.raise_if_not_logged()``.
        :raises Ignore: If the stanza comes from the component's own domain,
            is a ``chat`` message addressed to the component's bare JID, or is
            a message flagged by ``_ignore``.
        :raises XMPPError: ``registration-required`` if the sender has no
            session.
        """
        xmpp = self.xmpp
        # NOTE(review): compares the sender's domain to the component's bare
        # JID, presumably because a component's bound JID is a bare domain.
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready: bool = True,
        logged: bool = False,
    ) -> AnySession:
        """Return the session registered for ``jid``.

        :param jid: The user's JID.
        :param timeout: Forwarded to ``session.wait_for_ready``.
        :param wait_for_ready: If true, await ``session.wait_for_ready(timeout)``.
        :param logged: If true, call ``session.raise_if_not_logged()``.
        :raises XMPPError: ``registration-required`` if there is no session.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> AnyMUC:
        """Return the MUC that ``iq`` is addressed to.

        :raises XMPPError: ``bad-request`` if the stanza is addressed to the
            component itself rather than to a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc  # type:ignore[no-any-return]

    def _xmpp_msg_id_to_legacy(
        self,
        xmpp_id: str,
        recipient: AnyRecipient,
        origin: bool = False,
    ) -> str:
        """Map an XMPP message ID to the stored legacy ID.

        Falls back to returning ``xmpp_id`` unchanged if no mapping is stored.
        """
        with self.xmpp.store.session() as orm:
            sent = self.xmpp.store.id_map.get_legacy(
                orm, recipient.stored.id, xmpp_id, recipient.is_group, origin
            )
            if sent is not None:
                return sent

        return xmpp_id

    async def _get_recipient_and_thread(
        self, msg: Message
    ) -> tuple[AnyRecipient, str | None]:
        """Resolve the recipient of ``msg`` and its legacy thread ID, if any."""
        session = await self._get_session(msg)
        e: AnyRecipient = await get_recipient(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return e, legacy_thread

    async def _xmpp_to_legacy_thread(
        self, session: AnySession, msg: Message, recipient: RecipientType
    ) -> str | None:
        """Translate the ``<thread>`` of ``msg`` to a legacy thread ID.

        Returns ``None`` if the message has no thread. If the session treats
        message IDs as thread IDs, the message-ID mapping is used. Otherwise
        the stored thread mapping is used, and if there is none, a new legacy
        thread is created via ``recipient.create_thread`` and the mapping is
        stored.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(xmpp_thread, recipient)

        # Fast path: the mapping already exists, so the lock is not needed.
        legacy_thread_str = self._get_stored_legacy_thread(
            session, recipient, xmpp_thread
        )
        if legacy_thread_str is not None:
            return legacy_thread_str

        async with session.thread_creation_lock:
            # Re-check while holding the lock: a concurrent handler may have
            # created and stored this thread while we waited. Without this
            # re-check, two messages carrying the same new thread ID would
            # both call create_thread() and create duplicate legacy threads.
            legacy_thread_str = self._get_stored_legacy_thread(
                session, recipient, xmpp_thread
            )
            if legacy_thread_str is not None:
                return legacy_thread_str
            legacy_thread = await recipient.create_thread(xmpp_thread)
            with session.xmpp.store.session() as orm:
                session.xmpp.store.id_map.set_thread(
                    orm,
                    recipient.stored.id,
                    str(legacy_thread),
                    xmpp_thread,
                    recipient.is_group,
                )
                orm.commit()
        # NOTE(review): a freshly created thread is returned as whatever
        # create_thread() returned, while a stored one is returned as the
        # stored string; confirm callers accept both.
        return legacy_thread

    @staticmethod
    def _get_stored_legacy_thread(
        session: AnySession, recipient: RecipientType, xmpp_thread: str
    ) -> str | None:
        """Return the stored legacy thread ID for ``xmpp_thread``, or None."""
        with session.xmpp.store.session() as orm:
            return session.xmpp.store.id_map.get_thread(
                orm, recipient.stored.id, xmpp_thread, recipient.is_group
            )

133 

134 

def _ignore(session: AnySession, msg: Message) -> bool:
    """Tell whether an incoming message should be dropped.

    Messages whose ID starts with ``slidge-carbon-`` are always dropped.
    An ID listed in ``session.ignore_messages`` is dropped once: the entry
    is removed from that collection as it is consumed.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    pending = session.ignore_messages
    if msg_id in pending:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        pending.remove(msg_id)
        return True
    return False

144 

145 

async def get_recipient(session: AnySession, m: Message) -> AnyRecipient:
    """Resolve the contact or MUC that message ``m`` is addressed to.

    :raises XMPPError: ``not-acceptable`` for a groupchat message sent from a
        resource that is not among the MUC's user resources. In that case a
        task to kick that resource is also scheduled.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())  # type:ignore[no-any-return]

    muc = await session.bookmarks.by_jid(m.get_to())
    sender_resource = m.get_from().resource
    if sender_resource in muc.get_user_resources():
        return muc  # type:ignore[no-any-return]
    session.create_task(muc.kick_resource(sender_resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

157 

158 

# NOTE(review): P does not appear to be used anywhere in this module;
# confirm it is not imported from elsewhere before removing it.
P = ParamSpec("P")

# Type parameters for the stanza-handler signature used by
# exceptions_to_xmpp_errors: an async callable taking (self, stanza).
SelfType = TypeVar("SelfType")
StanzaType = TypeVar("StanzaType", bound=StanzaBase)
HandlerType = Callable[[SelfType, StanzaType], Awaitable[None]]

163 

164 

def exceptions_to_xmpp_errors(cb: HandlerType[Any, Any]) -> HandlerType[Any, Any]:
    """Wrap a stanza handler so that its failures become ``XMPPError``.

    - ``Ignore`` is swallowed and the handler returns quietly.
    - ``XMPPError`` propagates unchanged.
    - ``NotImplementedError`` becomes ``feature-not-implemented``.
    - ``ContactIsUser`` becomes ``bad-request``.
    - Any other ``Exception`` is logged with its traceback and becomes
      ``internal-server-error`` carrying the exception's text.
    """

    @wraps(cb)
    async def guarded(self: SelfType, stanza: StanzaType) -> None:
        try:
            await cb(self, stanza)
        except Ignore:
            # Deliberate silent drop: no error reply is produced.
            return
        except XMPPError:
            raise
        except NotImplementedError:
            log.debug("NotImplementedError raised in %s", cb)
            detail = f"{cb.__name__} is not implemented by the legacy module"
            raise XMPPError("feature-not-implemented", detail, clear=False)
        except ContactIsUser:
            detail = "Actions with your bridged self are not allowed."
            raise XMPPError("bad-request", detail)
        except Exception as exc:
            log.exception(
                "Failed to handle incoming stanza: %s - %s", self, stanza
            )
            raise XMPPError("internal-server-error", str(exc))

    return guarded

192 

193 

# Module logger. It is defined at the bottom of the file, but the functions
# above only look it up when called, so this ordering is safe.
log: logging.Logger = logging.getLogger(name=__name__)