Coverage for slidge / contact / roster.py: 75%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 04:38 +0000

1import asyncio 

2import logging 

3import warnings 

4from collections.abc import AsyncIterator, Iterator 

5from typing import TYPE_CHECKING, Generic 

6 

7from slixmpp import JID 

8from slixmpp.exceptions import IqError, IqTimeout, XMPPError 

9from sqlalchemy.orm import Session 

10from sqlalchemy.orm import Session as OrmSession 

11 

12from ..db.models import Contact, GatewayUser 

13from ..util import SubclassableOnce 

14from ..util.jid_escaping import EscapeMixin 

15from ..util.lock import NamedLockMixin 

16from ..util.types import AnySession, LegacyContactType 

17from ..util.util import timeit 

18 

if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers; none are
    # currently required by this module.
    ...

21 

22 

class ContactIsUser(Exception):
    """
    Raised by the roster lookups when the requested JID or legacy ID resolves
    to the gateway user's own legacy ID rather than to one of their contacts.
    """

25 

26 

class LegacyRoster(
    Generic[LegacyContactType],
    NamedLockMixin,
    EscapeMixin,
    SubclassableOnce,
):
    """
    Virtual roster of a gateway user, representing each of their legacy
    contacts as a contact object built around a stored :class:`Contact` row.
    Note that every lookup builds a fresh contact object from its stored row
    (see :meth:`from_store`).

    Every :class:`.BaseSession` instance has its own :class:`.LegacyRoster`
    instance, accessible via the :attr:`.BaseSession.contacts` attribute.

    Typically, you will mostly use :meth:`.LegacyRoster.by_legacy_id` to
    retrieve a contact instance.

    You might need to override :meth:`.LegacyRoster.legacy_id_to_jid_username`
    and/or :meth:`.LegacyRoster.jid_username_to_legacy_id` to incorporate
    custom logic when translating between JID local parts and legacy IDs
    (for instance, if legacy IDs may contain characters that are not valid
    in a JID local part).
    """

    # Concrete contact class instantiated by from_store(); presumably bound
    # by the SubclassableOnce / plugin machinery — TODO confirm.
    _contact_cls: type[LegacyContactType]

    def __init__(self, session: AnySession) -> None:
        super().__init__()

        self.log = logging.getLogger(f"{session.user_jid.bare}:roster")
        # Legacy ID of the gateway user themself; lookups resolving to it
        # raise ContactIsUser instead of returning a contact.
        self.user_legacy_id: str | None = None
        # Future created on the XMPP event loop; it is not resolved in this
        # class — NOTE(review): confirm who sets it and when.
        self.ready: asyncio.Future[bool] = session.xmpp.loop.create_future()

        self.session = session
        # True while _fill() iterates fill(): __update_if_needed() then skips
        # add_to_roster(), because _fill() pushes roster items itself.
        self.__filling = False

    @property
    def user(self) -> GatewayUser:
        """The gateway user owning this roster (shortcut for ``session.user``)."""
        return self.session.user

    def orm(self) -> Session:
        """Return a database session from the gateway's store (used as a context manager)."""
        return self.session.xmpp.store.session()

    def from_store(self, stored: Contact) -> LegacyContactType:
        """Build an instance of the roster's contact class around a stored row."""
        return self._contact_cls(self.session, stored=stored)

    def __repr__(self) -> str:
        return f"<Roster of {self.session.user_jid}>"

    def __iter__(self) -> Iterator[LegacyContactType]:
        """
        Iterate over this user's stored contacts whose info has already been
        fetched (``stored.updated``). Contacts are yielded while the database
        session used for the query is still open.
        """
        with self.orm() as orm:
            for stored in orm.query(Contact).filter_by(user=self.user).all():
                if stored.updated:
                    yield self.from_store(stored)

    def known_contacts(self, only_friends: bool = True) -> dict[str, LegacyContactType]:
        """
        Return the contacts yielded by :meth:`__iter__`, keyed by bare JID.

        :param only_friends: if True (the default), only keep contacts whose
            ``is_friend`` attribute is true.
        """
        if only_friends:
            return {c.jid.bare: c for c in self if c.is_friend}
        return {c.jid.bare: c for c in self}

    async def by_jid(self, contact_jid: JID) -> LegacyContactType:
        # Retrieve a contact by their JID.
        #
        # If the contact was not instantiated before, it will be created
        # using :meth:`slidge.LegacyRoster.jid_username_to_legacy_id` to infer
        # their legacy ID.
        #
        # (Kept as comments rather than a docstring, presumably so that this
        # method stays out of the generated API docs — TODO confirm.)
        #
        # :param contact_jid: the contact's JID; only its bare part is used.
        # :raises XMPPError: ``bad-request`` if the JID has no local part.
        # :raises ContactIsUser: if the JID maps to the gateway user themself.
        # :return: the contact, with its info fetched if it was not already.
        username = contact_jid.node
        if not username:
            raise XMPPError(
                "bad-request", "Contacts must have a local part in their JID"
            )
        contact_jid = JID(contact_jid.bare)
        async with self.lock(("username", username)):
            legacy_id = await self.jid_username_to_legacy_id(username)
            if legacy_id == self.user_legacy_id:
                raise ContactIsUser
            if self.get_lock(("legacy_id", legacy_id)):
                # by_legacy_id() is already working on this contact: delegate
                # to it instead of racing it.
                # NOTE(review): this waits for the ("legacy_id", …) lock while
                # holding the ("username", …) lock, and by_legacy_id() can do
                # the reverse — confirm NamedLockMixin rules out a deadlock.
                self.log.debug("Already updating %s via by_legacy_id()", contact_jid)
                return await self.by_legacy_id(legacy_id)

            with self.orm() as orm:
                stored = (
                    orm.query(Contact)
                    .filter_by(user=self.user, jid=contact_jid)
                    .one_or_none()
                )
                if stored is None:
                    stored = Contact(
                        user_account_id=self.session.user_pk,
                        legacy_id=legacy_id,
                        jid=contact_jid,
                    )
            # Done once the query session is closed, so ``stored`` is no
            # longer attached to it while the contact updates itself.
            return await self.__update_if_needed(stored)

    async def __update_if_needed(self, stored: Contact) -> LegacyContactType:
        # Wrap ``stored`` in a contact object. If its info was never fetched,
        # fetch it now and, for friends, add them to the roster — unless
        # _fill() is running, since it pushes roster items itself.
        contact = self.from_store(stored)
        if contact.stored.updated:
            return contact

        with contact.updating_info():
            await contact.update_info()
            if contact.is_friend and not self.__filling:
                await contact.add_to_roster()

        if contact.cached_presence is not None:
            contact._store_last_presence(contact.cached_presence)
        return contact

    def by_jid_only_if_exists(self, contact_jid: JID) -> LegacyContactType | None:
        """
        Return the stored contact for ``contact_jid`` if its info has already
        been fetched, else None. Unlike :meth:`by_jid`, this neither creates
        a new contact nor fetches its info.
        """
        with self.orm() as orm:
            stored = (
                orm.query(Contact)
                .filter_by(user=self.user, jid=contact_jid)
                .one_or_none()
            )
        if stored is not None and stored.updated:
            return self.from_store(stored)
        return None

    @timeit
    async def by_legacy_id(self, /, legacy_id: str) -> LegacyContactType:
        """
        Retrieve a contact by their legacy_id

        If the contact was not instantiated before, it will be created
        using :meth:`slidge.LegacyRoster.legacy_id_to_jid_username` to infer
        the local part of their JID.

        :param legacy_id: the contact's legacy ID.
        :raises ContactIsUser: if ``legacy_id`` is the gateway user's own
            legacy ID.
        :return: the contact, with its info fetched if it was not already.
        """
        if legacy_id == self.user_legacy_id:
            raise ContactIsUser
        async with self.lock(("legacy_id", legacy_id)):
            username = await self.legacy_id_to_jid_username(legacy_id)
            if self.get_lock(("username", username)):
                # by_jid() is already working on this contact: delegate to it
                # instead of racing it.
                self.log.debug("Already updating %s via by_jid()", username)

                return await self.by_jid(
                    JID(username + "@" + self.session.xmpp.boundjid.bare)
                )

            with self.orm() as orm:
                stored = (
                    orm.query(Contact)
                    .filter_by(user=self.user, legacy_id=str(legacy_id))
                    .one_or_none()
                )
                if stored is None:
                    stored = Contact(
                        user_account_id=self.session.user_pk,
                        legacy_id=str(legacy_id),
                        jid=JID(f"{username}@{self.session.xmpp.boundjid.bare}"),
                    )
            # Done once the query session is closed, so ``stored`` is no
            # longer attached to it while the contact updates itself.
            return await self.__update_if_needed(stored)

    @timeit
    async def _fill(self, orm: OrmSession) -> None:
        """
        Populate the roster by iterating :meth:`fill`, pushing changed items
        to the user's XMPP roster.

        The user's current roster is fetched via the ``xep_0356`` plugin. If
        that fails (or in test mode), contacts are still created through
        :meth:`fill`, but no roster item is pushed. Otherwise, an item is set
        via ``xep_0356`` for each contact that is a friend, whose item differs
        from the existing one (subscription, groups or name), and only if the
        user's ``roster_push`` preference is not disabled. Each contact
        successfully pushed gets its last presence re-sent.

        :param orm: database session committed once every contact has been
            processed.
        """
        try:
            if hasattr(self.session.xmpp, "TEST_MODE"):
                # dirty hack to avoid mocking xmpp server replies to this
                # during tests
                raise PermissionError
            iq = await self.session.xmpp["xep_0356"].get_roster(
                self.session.user_jid.bare
            )
            user_roster = iq["roster"]["items"]
        except (PermissionError, IqError, IqTimeout):
            user_roster = None

        self.__filling = True
        try:
            async for contact in self.fill():
                if user_roster is None:
                    continue
                item = contact.get_roster_item()
                old = user_roster.get(contact.jid.bare)
                if old is not None and all(
                    old[k] == item[contact.jid.bare].get(k)
                    for k in ("subscription", "groups", "name")
                ):
                    self.log.debug("No need to update roster")
                    continue
                self.log.debug("Updating roster")
                if not contact.is_friend:
                    continue
                if not self.session.user.preferences.get("roster_push", True):
                    continue
                try:
                    await self.session.xmpp["xep_0356"].set_roster(
                        self.session.user_jid.bare,
                        item,
                    )
                except (PermissionError, IqError, IqTimeout) as e:
                    warnings.warn(f"Could not add to roster: {e}")
                else:
                    contact.added_to_roster = True
                    contact.send_last_presence(force=True)
            orm.commit()
        finally:
            # Always reset the flag: if fill() (user code) raised and the flag
            # stayed True, __update_if_needed() would silently stop adding
            # new friends to the roster for the rest of the session.
            self.__filling = False

    async def fill(self) -> AsyncIterator[LegacyContactType]:
        """
        Populate slidge's "virtual roster".

        This should yield contacts that are meant to be added to the user's
        roster, typically by using ``await self.by_legacy_id(contact_id)``.
        Setting the contact nicknames, avatar, etc. should be done in
        :meth:`LegacyContact.update_info()`.

        It's not mandatory to override this method, but it is the recommended
        way to populate "friends" of the user. Calling
        ``await (await self.by_legacy_id(contact_id)).add_to_roster()``
        accomplishes the same thing, but doing it in here allows to batch
        DB queries and is better performance-wise.

        """
        # The unreachable ``yield`` makes this an async generator that yields
        # nothing by default.
        return
        yield

247 

248 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)