Coverage for slidge / core / dispatcher / disco.py: 92%

96 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-20 19:56 +0000

1import logging 

2from typing import TYPE_CHECKING, Any 

3 

4import sqlalchemy as sa 

5import sqlalchemy.orm 

6from slixmpp import JID 

7from slixmpp.exceptions import XMPPError 

8from slixmpp.plugins.xep_0004.stanza import Form # type:ignore[attr-defined] 

9from slixmpp.plugins.xep_0030.stanza.info import DiscoInfo 

10from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

11from slixmpp.types import OptJid 

12 

13from ...db.models import Room, Space 

14from ...util.types import AnySession 

15from .util import DispatcherMixin 

16 

17if TYPE_CHECKING: 

18 from slidge.util.types import AnyGateway, AnySession 

19 

20 

class DiscoMixin(DispatcherMixin):
    """Answer XEP-0030 service discovery queries on behalf of the gateway.

    :meth:`get_info` and :meth:`get_items` are registered as the ``xep_0030``
    ``get_info`` / ``get_items`` node handlers. They are used for queries to
    the gateway itself, to legacy contacts/groups/participants of the
    requesting user's session, and to XEP-0503 space nodes.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "AnyGateway") -> None:
        super().__init__(xmpp)

        # jid=None / node=None: presumably installs these as the default
        # handlers for every JID and node -- confirm against slixmpp's
        # ``set_node_handler`` documentation.
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_info",
            jid=None,
            node=None,
            handler=self.get_info,
        )

        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=None,
            node=None,
            handler=self.get_items,
        )

    async def get_info(
        self,
        jid: OptJid,
        node: str | None,
        ifrom: OptJid,
        data: Any,  # noqa:ANN401
    ) -> DiscoInfo | None:
        """Answer a disco#info query.

        Routing, in order:

        1. Queries sent by the gateway's own bare JID, or queries addressed to
           the gateway's bare JID (or no JID) with no node or with a
           ``http://slixmpp.com/ver/`` node, go to slixmpp's static handler.
        2. Queries addressed to a JID without a username are treated as
           XEP-0503 space lookups for ``node``.
        3. Everything else goes to the matching contact, group or participant
           of the requesting user's session.

        :param jid: The JID the query is addressed to.
        :param node: The disco node being queried, if any.
        :param ifrom: The JID of the requester.
        :param data: Extra data forwarded unchanged to slixmpp's static
            handler.
        :return: The disco#info payload.
        :raises XMPPError: ``subscription-required`` if the requester is
            unknown; ``item-not-found`` if the target JID/node does not
            designate a known entity or space.
        """
        if ifrom == self.xmpp.boundjid.bare or (
            jid in (self.xmpp.boundjid.bare, None)
            and (not node or node.startswith("http://slixmpp.com/ver/"))
        ):
            return self.xmpp.plugin["xep_0030"].static.get_info(
                jid, node, ifrom, data
            )

        if ifrom is None:
            raise XMPPError("subscription-required")

        # These used to be ``assert`` statements. They are reachable from
        # remote input and are stripped under ``python -O``, so answer with a
        # proper stanza error instead.
        if jid is None:
            raise XMPPError("item-not-found")

        session = await self._get_session_from_jid(jid=ifrom)

        if not jid.username:
            # A username-less JID that is not the plain gateway query handled
            # above: only space nodes can be looked up here.
            if not node:
                raise XMPPError("item-not-found")
            return await self.__spaces_info(session, node)

        log.debug("Looking for entity: %s", jid)

        entity = await session.get_contact_or_group_or_participant(jid)

        if entity is None:
            raise XMPPError("item-not-found")

        return await entity.get_disco_info(jid, node)

    async def __spaces_info(self, session: AnySession, node: str) -> DiscoInfo:
        """Build the disco#info result describing the space behind ``node``.

        The result advertises a pubsub leaf identity, a set of pubsub
        features, and a ``pubsub#meta-data`` form. The form carries the space
        type, creator, title, owners, and optionally the member count and
        description.

        :param session: The requesting user's session.
        :param node: The space node being queried.
        :raises XMPPError: ``item-not-found`` if no space matches ``node``.
        """
        legacy_id = await session.bookmarks.space_node_to_legacy_id(node)
        # expire_on_commit=False + full=True: the space is read after this
        # ORM session is closed.
        with self.xmpp.store.session(expire_on_commit=False) as orm:
            space = self.xmpp.store.spaces.get_by_legacy_id(
                orm, session.user_pk, legacy_id, full=True
            )
        if space is None:
            raise XMPPError("item-not-found", f"No space for node '{node}'")
        space = await session.bookmarks.update_space_if_needed(space)

        form = Form()
        form["type"] = "result"
        form.add_field(
            var="FORM_TYPE",
            ftype="hidden",
            value="http://jabber.org/protocol/pubsub#meta-data",
        )
        form.add_field(
            var="pubsub#type",
            value="urn:xmpp:spaces:0",
        )
        # Without a known creator, the gateway itself is reported as creator.
        creator = (
            self.xmpp.boundjid.bare
            if space.creator is None
            else str(space.creator.jid)
        )
        form.add_field(var="pubsub#creator", ftype="jid-single", value=creator)
        form.add_field(var="pubsub#title", value=space.name)
        form.add_field(
            var="pubsub#owner",
            ftype="jid-multi",
            value=[str(owner.jid) for owner in space.owners],
        )
        if space.member_count is not None:
            form.add_field(var="pubsub#num_subscribers", value=str(space.member_count))
        if space.description is not None:
            form.add_field(var="pubsub#description", value=space.description)

        info = DiscoInfo()
        info.add_identity(category="pubsub", itype="leaf")
        info.add_feature("http://jabber.org/protocol/pubsub")
        for feat in (
            "meta-data",
            "item-ids",
            "manage-subscriptions",
            "modify-affiliations",
            "outcast-affiliation",
            "retract-items",
            "retrieve-affiliations",
            "retrieve-items",
            "retrieve-subscriptions",
            "subscribe",
            "subscription-notifications",
        ):
            info.add_feature(f"http://jabber.org/protocol/pubsub#{feat}")
        info.append(form)
        return info

    async def get_items(
        self,
        jid: OptJid,
        node: str | None,
        ifrom: OptJid,
        data: Any,  # noqa:ANN401
    ) -> DiscoItems:
        """Answer a disco#items query.

        Queries to the gateway's bare JID (or no JID) without a node list the
        user's rooms or spaces (see :meth:`__list_rooms`). Queries to the
        gateway with a node get an empty list. Any other JID is forwarded to
        the matching contact, group or participant.

        :param jid: The JID the query is addressed to.
        :param node: The disco node being queried, if any.
        :param ifrom: The JID of the requester.
        :param data: Request data; inspected for an ``included_types`` filter.
        :raises XMPPError: ``bad-request`` if the requester is unknown;
            ``item-not-found`` if no matching entity exists.
        """
        if ifrom is None:
            raise XMPPError("bad-request")

        session = await self._get_session_from_jid(ifrom)

        if not jid or jid == self.xmpp.boundjid.bare:
            if node:
                return DiscoItems()
            return await self.__list_rooms(session, data)

        entity = await session.get_contact_or_group_or_participant(jid)

        if entity is None:
            raise XMPPError("item-not-found")

        return await entity.get_disco_items(node)

    async def __list_rooms(
        self,
        session: "AnySession",
        data: dict,  # type:ignore[type-arg]
    ) -> DiscoItems:
        """List the user's rooms, or their spaces when filtered for them.

        The filter is read from ``data["disco_items"]["filter"]["included_types"]``
        and treated as empty if any key is missing.

        - Empty filter: rooms.
        - Exactly ``["urn:xmpp:spaces:0"]``: spaces. This requires
          ``self.xmpp.SPACES``, otherwise ``feature-not-implemented`` is
          raised.
        - Anything else: ``item-not-found``.
        """
        try:
            included_types = data["disco_items"]["filter"]["included_types"]
        except KeyError:
            included_types = []

        d = DiscoItems()
        with self.xmpp.store.session() as orm:
            if included_types == ["urn:xmpp:spaces:0"]:
                if not self.xmpp.SPACES:
                    raise XMPPError(
                        "feature-not-implemented",
                        "This gateway does not support XEP-0503",
                    )
                await self.__spaces(orm, session, d)
            elif not included_types:
                self.__rooms(orm, session, d)
            else:
                raise XMPPError("item-not-found", f"No items for {included_types}")

        return d

    async def __spaces(
        self, orm: sqlalchemy.orm.Session, session: AnySession, d: DiscoItems
    ) -> None:
        """Append one item per space of the user to ``d``, ordered by name.

        Each item points at the gateway's bare JID and uses the space's node.
        """
        await session.bookmarks.update_spaces_info()
        gateway_jid = JID(self.xmpp.boundjid.bare)
        # Materialize first: an await happens for every row below, so avoid
        # keeping the result cursor open across those suspensions.
        spaces = (
            orm.execute(
                sa.select(Space)
                .options(sa.orm.load_only(Space.legacy_id, Space.name))
                .filter_by(user=session.user)
                .order_by(Space.name)
            )
            .scalars()
            .all()
        )
        for space in spaces:
            d.add_item(
                gateway_jid,
                name=space.name,
                node=await session.bookmarks.space_legacy_id_to_node(space.legacy_id),
            )

    @staticmethod
    def __rooms(
        orm: sqlalchemy.orm.Session, session: AnySession, d: DiscoItems
    ) -> None:
        """Append one item per room of the user to ``d``, ordered by name."""
        for room in orm.execute(
            sa.select(Room)
            .options(sa.orm.load_only(Room.jid, Room.name))
            .filter_by(user=session.user)
            .order_by(Room.name)
        ).scalars():
            d.add_item(room.jid, name=room.name)

209 

210 

# Module logger, named after this module.
log = logging.getLogger(name=__name__)