Coverage for slidge/core/dispatcher/util.py: 89%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import logging 

2from functools import wraps 

3from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar 

4 

5from slixmpp import JID, Iq, Message, Presence 

6from slixmpp.exceptions import XMPPError 

7from slixmpp.xmlstream import StanzaBase 

8 

9from ...util.types import Recipient, RecipientType 

10from ..session import BaseSession 

11 

12if TYPE_CHECKING: 

13 from slidge import BaseGateway 

14 from slidge.group import LegacyMUC 

15 

16 

class Ignore(BaseException):
    """
    Raised to abort the handling of an incoming stanza without replying.

    It derives from ``BaseException`` rather than ``Exception`` so that
    generic ``except Exception`` clauses do not catch it;
    ``exceptions_to_xmpp_errors`` catches it explicitly and drops it.
    """

19 

20 

class DispatcherMixin:
    """
    Helpers shared by the stanza dispatchers of a gateway.

    Provides session lookup for incoming stanzas, MUC resolution, and
    translation of XMPP message/thread IDs into legacy IDs.
    """

    def __init__(self, xmpp: "BaseGateway"):
        self.xmpp = xmpp

    async def _get_session(
        self,
        stanza: Message | Presence | Iq,
        timeout: int | None = 10,
        wait_for_ready=True,
        logged=False,
    ) -> BaseSession:
        """
        Return the session of the user who sent ``stanza``.

        :param stanza: The incoming stanza.
        :param timeout: Forwarded to ``session.wait_for_ready()``.
        :param wait_for_ready: Wait for the session to be ready before returning.
        :param logged: Require the session to be logged in.
        :raises Ignore: if the sender's domain is the component's bare JID
            (an echo), if it is a 'chat' message addressed to the component's
            bare JID, or if ``_ignore()`` flags the message.
        :raises XMPPError: 'registration-required' if the sender has no session.
        """
        xmpp = self.xmpp
        if stanza.get_from().server == xmpp.boundjid.bare:
            log.debug("Ignoring echo")
            raise Ignore
        if (
            isinstance(stanza, Message)
            and stanza.get_type() == "chat"
            and stanza.get_to() == xmpp.boundjid.bare
        ):
            log.debug("Ignoring message to component")
            raise Ignore
        session = await self._get_session_from_jid(
            stanza.get_from(), timeout, wait_for_ready, logged
        )
        if isinstance(stanza, Message) and _ignore(session, stanza):
            raise Ignore
        return session

    async def _get_session_from_jid(
        self,
        jid: JID,
        timeout: int | None = 10,
        wait_for_ready=True,
        logged=False,
    ) -> BaseSession:
        """
        Return the session associated with ``jid``.

        :raises XMPPError: 'registration-required' if there is no session.
        """
        session = self.xmpp.get_session_from_jid(jid)
        if session is None:
            raise XMPPError("registration-required")
        if logged:
            session.raise_if_not_logged()
        if wait_for_ready:
            await session.wait_for_ready(timeout)
        return session

    async def get_muc_from_stanza(self, iq: Iq | Message | Presence) -> "LegacyMUC":
        """
        Return the MUC a stanza is addressed to.

        :raises XMPPError: 'bad-request' if the stanza is addressed to the
            component's bare JID instead of a MUC.
        """
        ito = iq.get_to()
        if ito == self.xmpp.boundjid.bare:
            raise XMPPError("bad-request", text="This is only handled for MUCs")

        session = await self._get_session(iq, logged=True)
        muc = await session.bookmarks.by_jid(ito)
        return muc

    def _xmpp_msg_id_to_legacy(self, session: "BaseSession", xmpp_id: str):
        """
        Convert an XMPP message ID into a legacy message ID.

        Looks the ID up in ``store.sent``, then ``store.multi``, and finally
        falls back to ``session.xmpp_to_legacy_msg_id()``.

        :raises XMPPError: re-raised as-is from the fallback, or
            'internal-server-error' if the fallback raises anything else.
        """
        sent = self.xmpp.store.sent.get_legacy_id(session.user_pk, xmpp_id)
        if sent is not None:
            return self.xmpp.LEGACY_MSG_ID_TYPE(sent)

        multi = self.xmpp.store.multi.get_legacy_id(session.user_pk, xmpp_id)
        if multi:
            return self.xmpp.LEGACY_MSG_ID_TYPE(multi)

        try:
            return session.xmpp_to_legacy_msg_id(xmpp_id)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Couldn't convert xmpp msg ID to legacy ID.", exc_info=e)
            # Chain the original exception so the root cause is preserved.
            raise XMPPError(
                "internal-server-error", "Couldn't convert xmpp msg ID to legacy ID."
            ) from e

    async def _get_session_entity_thread(
        self, msg: Message
    ) -> tuple["BaseSession", Recipient, int | str | None]:
        """
        Return ``(session, recipient, legacy_thread)`` for an outgoing message.

        ``legacy_thread`` is ``None`` when the message carries no thread.
        """
        session = await self._get_session(msg)
        e: Recipient = await _get_entity(session, msg)
        legacy_thread = await self._xmpp_to_legacy_thread(session, msg, e)
        return session, e, legacy_thread

    async def _xmpp_to_legacy_thread(
        self, session: "BaseSession", msg: Message, recipient: RecipientType
    ):
        """
        Map the XMPP thread of ``msg`` to a legacy thread ID.

        Returns ``None`` if the message has no thread. If the session treats
        message IDs as thread IDs, the thread is converted like a message ID.
        Otherwise the stored mapping is used. If no mapping exists, a thread
        is created via ``recipient.create_thread()`` and the mapping is
        stored.
        """
        xmpp_thread = msg["thread"]
        if not xmpp_thread:
            return None

        if session.MESSAGE_IDS_ARE_THREAD_IDS:
            return self._xmpp_msg_id_to_legacy(session, xmpp_thread)

        # Fast path: the thread is already known, no need to take the lock.
        legacy_thread = self._stored_legacy_thread(session, xmpp_thread)
        if legacy_thread is not None:
            return legacy_thread

        async with session.thread_creation_lock:
            # Re-check under the lock: another coroutine may have created and
            # stored this very thread while we were waiting to acquire it.
            # Without this, concurrent messages would create duplicate threads.
            legacy_thread = self._stored_legacy_thread(session, xmpp_thread)
            if legacy_thread is not None:
                return legacy_thread
            legacy_thread = await recipient.create_thread(xmpp_thread)
            session.xmpp.store.sent.set_thread(
                session.user_pk, str(legacy_thread), xmpp_thread
            )
        return legacy_thread

    @staticmethod
    def _stored_legacy_thread(session: "BaseSession", xmpp_thread: str):
        """Return the stored legacy thread for ``xmpp_thread``, or ``None``."""
        legacy_thread_str = session.xmpp.store.sent.get_legacy_thread(
            session.user_pk, xmpp_thread
        )
        if legacy_thread_str is None:
            return None
        return session.xmpp.LEGACY_MSG_ID_TYPE(legacy_thread_str)

123 

124 

def _ignore(session: "BaseSession", msg: Message):
    """
    Tell whether an incoming message must be dropped by the dispatcher.

    True for message IDs starting with ``slidge-carbon-``, and for IDs listed
    in ``session.ignore_messages``. In the latter case the ID is logged and
    removed from that collection.
    """
    msg_id = msg.get_id()
    if msg_id.startswith("slidge-carbon-"):
        return True
    if msg_id in session.ignore_messages:
        session.log.debug("Ignored sent carbon: %s", msg_id)
        session.ignore_messages.remove(msg_id)
        return True
    return False

134 

135 

async def _get_entity(session: "BaseSession", m: Message) -> RecipientType:
    """
    Resolve the recipient (contact or MUC) an outgoing message is addressed to.

    :raises XMPPError: 'not-acceptable' for a groupchat message sent from a
        resource that is not among the MUC's user resources.
    """
    session.raise_if_not_logged()
    if m.get_type() != "groupchat":
        return await session.contacts.by_jid(m.get_to())

    muc = await session.bookmarks.by_jid(m.get_to())
    resource = m.get_from().resource
    if resource in muc.get_user_resources():
        return muc
    # Sending resource is not joined: schedule a kick for it and reject.
    session.create_task(muc.kick_resource(resource))
    raise XMPPError("not-acceptable", "You are not connected to this chat")

147 

148 

# Any slixmpp stanza subtype.
StanzaType = TypeVar("StanzaType", bound=StanzaBase)

# Async stanza handler: ``(owner, stanza) -> None``. The first argument is
# typed ``Any`` (presumably the dispatcher instance — confirm with callers).
HandlerType = Callable[
    [Any, StanzaType],
    Awaitable[None],
]

151 

152 

def exceptions_to_xmpp_errors(cb: HandlerType) -> HandlerType:
    """
    Decorate an async stanza handler so that failures become XMPP errors.

    - ``Ignore`` is swallowed silently.
    - ``XMPPError`` is re-raised unchanged.
    - ``NotImplementedError`` becomes 'feature-not-implemented'.
    - Any other ``Exception`` is logged and becomes 'internal-server-error'
      carrying the exception's text.

    Other ``BaseException`` subclasses propagate unchanged. Translated errors
    keep the original exception as their ``__cause__``.
    """

    @wraps(cb)
    async def wrapped(*args, **kwargs):
        try:
            await cb(*args, **kwargs)
        except Ignore:
            pass
        except XMPPError:
            raise
        except NotImplementedError as e:
            log.debug("NotImplementedError raised in %s", cb)
            raise XMPPError(
                "feature-not-implemented", "Not implemented by the legacy module"
            ) from e
        except Exception as e:
            log.error("Failed to handle incoming stanza: %s", args, exc_info=e)
            raise XMPPError("internal-server-error", str(e)) from e

    return wrapped

172 

173 

# Module logger. The functions above only look it up when called, so it can
# be defined at the bottom of the module.
log: logging.Logger = logging.getLogger(__name__)