Coverage for slidge/db/store.py: 87%

666 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import json 

5import logging 

6import uuid 

7from contextlib import contextmanager 

8from datetime import datetime, timedelta, timezone 

9from mimetypes import guess_extension 

10from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type 

11 

12from slixmpp import JID, Iq, Message, Presence 

13from slixmpp.exceptions import XMPPError 

14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary 

15from sqlalchemy import Engine, delete, select, update 

16from sqlalchemy.orm import Session, attributes, load_only 

17from sqlalchemy.sql.functions import count 

18 

19from ..core import config 

20from ..util.archive_msg import HistoryMessage 

21from ..util.types import URL, CachedPresence, ClientType 

22from ..util.types import Hat as HatTuple 

23from ..util.types import MamMetadata, MucAffiliation, MucRole, Sticker 

24from .meta import Base 

25from .models import ( 

26 ArchivedMessage, 

27 ArchivedMessageSource, 

28 Attachment, 

29 Avatar, 

30 Bob, 

31 Contact, 

32 ContactSent, 

33 GatewayUser, 

34 Hat, 

35 LegacyIdsMulti, 

36 Participant, 

37 Room, 

38 XmppIdsMulti, 

39 XmppToLegacyEnum, 

40 XmppToLegacyIds, 

41 participant_hats, 

42) 

43 

44if TYPE_CHECKING: 

45 from ..contact.contact import LegacyContact 

46 from ..group.participant import LegacyParticipant 

47 from ..group.room import LegacyMUC 

48 

49 

class EngineMixin:
    """Base class of every store: keeps the engine and hands out sessions."""

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """
        Yield a SQLAlchemy session.

        When a session is already active (tracked in the module-level
        ``_session`` global), that same session is yielded again, so nested
        store calls share one session; ``session_kwargs`` are ignored in that
        case. Otherwise a new session is opened on this store's engine,
        published as the shared session while the ``with`` block runs, and
        closed (and unpublished) afterwards.
        """
        global _session
        shared = _session
        if shared is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    _session = None
        else:
            yield shared

66 

67 

class UpdatedMixin(EngineMixin):
    """
    Store mixin for tables that carry an ``updated`` flag.

    On construction, ``updated`` is reset to False on every row of ``model``
    (subclasses presumably set it back to True when a row is refreshed).
    """

    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        reset_flag = update(self.model).values(updated=False)  # type:ignore
        with self.session() as session:
            session.execute(reset_flag)
            session.commit()

    def get_by_pk(self, pk: int) -> Optional[Base]:
        """Return the ``model`` row whose ``id`` equals ``pk``, or None."""
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        with self.session() as session:
            return session.execute(stmt).scalar()

82 

83 

class SlidgeStore(EngineMixin):
    """Facade grouping every specialised store; all of them share one engine."""

    def __init__(self, engine: Engine):
        super().__init__(engine)
        # Several stores run start-up queries in their constructors, so the
        # construction order below is kept deliberately.
        self.users = UserStore(engine)
        self.avatars = AvatarStore(engine)
        self.contacts = ContactStore(engine)
        self.mam = MAMStore(engine)
        self.multi = MultiStore(engine)
        self.attachments = AttachmentStore(engine)
        self.rooms = RoomStore(engine)
        self.sent = SentStore(engine)
        self.participants = ParticipantStore(engine)
        self.bob = BobStore(engine)

97 

98 

class UserStore(EngineMixin):
    """Persistence for gateway users (``GatewayUser`` rows), keyed by bare JID."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """
        Return the user registered for ``jid``, creating it if necessary.

        The resource part of ``jid`` is dropped. ``legacy_module_data`` is only
        used when a new row is created. The session is requested with
        ``expire_on_commit=False`` so the returned object stays readable after
        the commit.
        """
        if jid.resource:
            jid = JID(jid.bare)
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist ``user``, including in-place changes to its dict-like columns."""
        # In-place mutation of these (presumably JSON) columns is not detected
        # by SQLAlchemy, so they are flagged dirty explicitly.
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Lazily yield every user; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user registered with the bare JID of ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent ``stanza`` (looked up by its 'from'), or None."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """
        Delete the user registered with the bare JID of ``jid``.

        Does nothing when no such user exists. Previously, a missing user was
        passed straight to ``Session.delete`` and made it raise.
        """
        with self.session() as session:
            # self.get() reuses this same session through the shared-session
            # mechanism, so the returned row is attached to ``session``.
            user = self.get(jid)
            if user is None:
                log.debug("No user to delete for %s", jid)
                return
            session.delete(user)
            session.commit()

    def set_avatar_hash(self, pk: int, h: str | None = None) -> None:
        """Store ``h`` as the avatar hash of user ``pk`` (None clears it)."""
        with self.session() as session:
            session.execute(
                update(GatewayUser).where(GatewayUser.id == pk).values(avatar_hash=h)
            )
            session.commit()

146 

147 

class AvatarStore(EngineMixin):
    """Lookup, listing and removal of cached ``Avatar`` rows."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar whose ``url`` column equals ``url``, or None."""
        query = select(Avatar).where(Avatar.url == url)
        with self.session() as session:
            return session.execute(query).scalar()

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or None."""
        query = select(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            return session.execute(query).scalar()

    def delete_by_pk(self, pk: int):
        """Delete the avatar with primary key ``pk`` (no-op if absent)."""
        query = delete(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            session.execute(query)
            session.commit()

    def get_all(self) -> Iterator[Avatar]:
        """Lazily yield every avatar; the session stays open while iterating."""
        with self.session() as session:
            for avatar in session.execute(select(Avatar)).scalars():
                yield avatar

165 

166 

class SentStore(EngineMixin):
    """
    Per-user mapping between legacy message IDs and XMPP message IDs.

    Every ``XmppToLegacyIds`` row is tagged with an ``XmppToLegacyEnum`` kind
    (DM, GROUP_CHAT or THREAD), so one table serves all three lookups.
    """

    def _insert(self, user_pk: int, legacy_id: str, xmpp_id: str, kind) -> None:
        """Unconditionally insert one mapping row of the given ``kind``."""
        with self.session() as session:
            session.add(
                XmppToLegacyIds(
                    user_account_id=user_pk,
                    legacy_id=legacy_id,
                    xmpp_id=xmpp_id,
                    type=kind,
                )
            )
            session.commit()

    def _lookup(self, wanted, known, value: str, user_pk: int, kind) -> Optional[str]:
        """Return column ``wanted`` of the first ``kind`` row where ``known == value``."""
        with self.session() as session:
            return session.execute(
                select(wanted).where(
                    XmppToLegacyIds.user_account_id == user_pk,
                    known == value,
                    XmppToLegacyIds.type == kind,
                )
            ).scalar()

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """
        Record a DM mapping between ``legacy_id`` and ``xmpp_id``.

        A row with exactly this user/legacy_id/xmpp_id triple (of any type) is
        reused and re-tagged as DM instead of inserting a duplicate.
        """
        with self.session() as session:
            row = (
                session.query(XmppToLegacyIds)
                .filter(
                    XmppToLegacyIds.user_account_id == user_pk,
                    XmppToLegacyIds.legacy_id == legacy_id,
                    XmppToLegacyIds.xmpp_id == xmpp_id,
                )
                .scalar()
            )
            if row is not None:
                log.debug("Resetting a DM from sent store")
            else:
                row = XmppToLegacyIds(user_account_id=user_pk)
            row.legacy_id = legacy_id
            row.xmpp_id = xmpp_id
            row.type = XmppToLegacyEnum.DM
            session.add(row)
            session.commit()

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of the DM with legacy ID ``legacy_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.xmpp_id,
            XmppToLegacyIds.legacy_id,
            legacy_id,
            user_pk,
            XmppToLegacyEnum.DM,
        )

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of the DM with XMPP ID ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            user_pk,
            XmppToLegacyEnum.DM,
        )

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a group-chat mapping (always inserts a new row)."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID of the group message ``legacy_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.xmpp_id,
            XmppToLegacyIds.legacy_id,
            legacy_id,
            user_pk,
            XmppToLegacyEnum.GROUP_CHAT,
        )

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID of the group message ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            user_pk,
            XmppToLegacyEnum.GROUP_CHAT,
        )

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a thread mapping (always inserts a new row)."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy thread ID mapped to ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            user_pk,
            XmppToLegacyEnum.THREAD,
        )

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """Return whether any mapping row (of any kind) exists for ``legacy_id``."""
        with self.session() as session:
            found = session.execute(
                select(XmppToLegacyIds.legacy_id)
                .where(XmppToLegacyIds.user_account_id == user_pk)
                .where(XmppToLegacyIds.legacy_id == legacy_id)
            ).scalar()
            return found is not None

264 

265 

class ContactStore(UpdatedMixin):
    """Persistence for the legacy contacts (``Contact`` rows) of each gateway user."""

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        # Invalidate every presence cached by a previous run.
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Lazily yield every contact of user ``user_pk``."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the contact of ``user_pk`` whose JID is the bare ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the contact of ``user_pk`` with ``legacy_id``, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set the ``nick`` column of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """
        Return the cached presence of a contact.

        Returns None when the contact does not exist or when its
        ``cached_presence`` flag is false.
        """
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            if presence is None or not presence[-1]:
                return None
            # Everything but the trailing cached_presence flag.
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Store ``presence`` on the contact and mark it as cached."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int):
        """Clear the cached presence fields of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(
        self, contact_pk: int, avatar_pk: Optional[int], avatar_legacy_id: Optional[str]
    ):
        """Link a contact to an avatar row and its legacy avatar ID."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the legacy avatar ID, or None if the contact or its avatar is missing."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar_legacy_id

    def update(self, contact: "LegacyContact", commit=True) -> int:
        """
        Insert or update the row backing ``contact`` and return its primary key.

        A new row is created when ``contact.contact_pk`` is None; otherwise the
        existing row is loaded (and must exist). With ``commit=False`` the
        session is flushed instead of committed, so that a freshly inserted
        row still gets its autoincrement ``id`` rather than returning None.
        """
        with self.session() as session:
            if contact.contact_pk is None:
                if contact.cached_presence is not None:
                    presence_kwargs = contact.cached_presence._asdict()
                    presence_kwargs["cached_presence"] = True
                else:
                    presence_kwargs = {}
                row = Contact(
                    jid=contact.jid.bare,
                    legacy_id=str(contact.legacy_id),
                    user_account_id=contact.user_pk,
                    **presence_kwargs,
                )
            else:
                row = (
                    session.query(Contact)
                    .filter(Contact.id == contact.contact_pk)
                    .one()
                )
            row.nick = contact.name
            row.is_friend = contact.is_friend
            row.added_to_roster = contact.added_to_roster
            row.updated = True
            row.extra_attributes = contact.serialize_extra_attributes()
            row.caps_ver = contact._caps_ver
            row.vcard = contact._vcard
            row.vcard_fetched = contact._vcard_fetched
            row.client_type = contact.client_type
            session.add(row)
            if commit:
                session.commit()
            elif row.id is None:
                # The INSERT is still pending; flush to obtain the id.
                session.flush()
            return row.id

    def set_vcard(self, contact_pk: int, vcard: str | None) -> None:
        """Store a contact's vCard and mark it as fetched."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(vcard=vcard, vcard_fetched=True)
            )
            session.commit()

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Record that a contact sent ``msg_id``; duplicates are logged and ignored."""
        with self.session() as session:
            if (
                session.query(ContactSent.id)
                .where(ContactSent.contact_id == contact_pk)
                .where(ContactSent.msg_id == msg_id)
                .first()
            ) is not None:
                log.warning(
                    "Contact %s has already sent message %s", contact_pk, msg_id
                )
                return
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """
        Remove and return a contact's recorded message IDs up to ``msg_id``.

        IDs are returned in ascending row-id order, up to and including
        ``msg_id``; if ``msg_id`` is not recorded, every recorded ID is popped.
        The deletion is now committed (it was previously rolled back when the
        session closed).
        """
        result: list[str] = []
        to_del: list[int] = []
        with self.session() as session:
            # Fetch eagerly so no SELECT cursor is left open during the DELETE
            # and COMMIT that follow.
            rows = session.execute(
                select(ContactSent.id, ContactSent.msg_id)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).all()
            for row_id, row_msg_id in rows:
                to_del.append(row_id)
                result.append(row_msg_id)
                if row_msg_id == msg_id:
                    break
            if to_del:
                session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the ``is_friend`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()

    def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
        """Set the ``added_to_roster`` flag of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(added_to_roster=value)
            )
            session.commit()

    def delete(self, contact_pk: int) -> None:
        """Delete a contact row (no-op if absent)."""
        with self.session() as session:
            session.execute(delete(Contact).where(Contact.id == contact_pk))
            session.commit()

    def set_client_type(self, contact_pk: int, value: ClientType):
        """Set the ``client_type`` column of a contact."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(client_type=value)
            )
            session.commit()

470 

471 

class MAMStore(EngineMixin):
    """Per-room message archive (``ArchivedMessage`` rows)."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # On start-up, every archived message is re-tagged as BACKFILL.
        with self.session() as session:
            session.execute(
                update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
            )
            session.commit()

    def nuke_older_than(self, days: int) -> None:
        """Delete archived messages whose timestamp is more than ``days`` days old."""
        # get_messages() interprets stored timestamps as naive UTC, so the
        # cutoff must be naive UTC too (datetime.now() would be local time).
        cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(
        self,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """
        Archive ``message`` in room ``room_pk``, or update it if already archived.

        An existing row is matched by stanza ID first, then by ``legacy_msg_id``
        when one is given. ``archive_only`` tags the row as BACKFILL instead of
        LIVE.
        """
        with self.session() as session:
            source = (
                ArchivedMessageSource.BACKFILL
                if archive_only
                else ArchivedMessageSource.LIVE
            )
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is None and legacy_msg_id is not None:
                existing = session.execute(
                    select(ArchivedMessage)
                    .where(ArchivedMessage.room_id == room_pk)
                    .where(ArchivedMessage.legacy_id == legacy_msg_id)
                ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                existing.source = source
                existing.legacy_id = legacy_msg_id
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """
        Return the timestamp of message ``stanza_id`` in room ``room_pk``.

        Raises XMPPError("item-not-found") when the message is not in that room.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """
        Yield archived messages of a room matching all the given filters.

        :param start_date: inclusive lower bound on the timestamp.
        :param end_date: inclusive upper bound on the timestamp.
        :param before_id: only messages strictly older than this stanza ID
            (looked up in the same room).
        :param after_id: only messages strictly newer than this stanza ID
            (looked up in the same room).
        :param ids: only these stanza IDs; all of them must be found.
        :param last_page_n: keep at most this many messages from the end of
            the chronological order.
        :param sender: only messages whose ``author_jid`` equals this.
        :param flip: yield newest first instead of oldest first.
        :raises XMPPError: ``item-not-found`` if ``before_id``/``after_id`` or
            one of ``ids`` cannot be found. Being a generator, the error is
            raised on first iteration.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._timestamp_of(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._timestamp_of(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.where(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            # Compare against distinct IDs: IN matches each row once, so
            # duplicates in ``ids`` must not count as missing messages.
            if ids and len(msgs) != len(set(ids)):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if flip:
                    msgs = msgs[:last_page_n]
                else:
                    msgs = msgs[-last_page_n:]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
        """Return the oldest archived message of a room (optionally one with a legacy ID)."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            )
            if with_legacy_id:
                q = q.filter(ArchivedMessage.legacy_id.isnot(None))
            return session.execute(q).scalar()

    def get_last(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of a room, optionally of one ``source``."""
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

            if source is not None:
                q = q.where(ArchivedMessage.source == source)

            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest and newest messages (empty list if none)."""
        r = []
        # One shared session for both nested lookups.
        with self.session():
            first = self.get_first(room_pk)
            if first is not None:
                r.append(MamMetadata(first.stanza_id, first.timestamp))
            last = self.get_last(room_pk)
            if last is not None:
                r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    def get_most_recent_with_legacy_id(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message that has a legacy ID, optionally of one ``source``."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
            )
            if source is not None:
                q = q.where(ArchivedMessage.source == source)
            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_least_recent_with_legacy_id_after(
        self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
    ) -> ArchivedMessage | None:
        """
        Return the oldest message with a legacy ID that is newer than the message
        whose legacy ID is ``after_id``.

        Returns None when ``after_id`` is not archived in this room.
        """
        with self.session() as session:
            after_timestamp = (
                session.query(ArchivedMessage.timestamp)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == after_id)
                .scalar()
            )
            if after_timestamp is None:
                # "timestamp > NULL" matches nothing anyway; skip the query.
                return None
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
                .where(ArchivedMessage.source == source)
                .where(ArchivedMessage.timestamp > after_timestamp)
            )
            return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
        """Return the first archived message of a room with ``legacy_id``, or None."""
        with self.session() as session:
            return (
                session.query(ArchivedMessage)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == legacy_id)
                .first()
            )

681 

682 

class MultiStore(EngineMixin):
    """
    Per-user mapping of one legacy message ID to several XMPP message IDs.

    A ``LegacyIdsMulti`` row owns ``XmppIdsMulti`` rows through its
    ``xmpp_ids`` relationship; each ``XmppIdsMulti`` points back to its
    parent through ``legacy_ids_multi``.
    """

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """
        Return the XMPP IDs attached to the same legacy message as ``xmpp_id``.

        Returns an empty list when ``xmpp_id`` is unknown.
        """
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """
        Map ``legacy_msg_id`` to ``xmpp_ids`` for user ``user_pk``.

        If a mapping for ``legacy_msg_id`` already exists, it is deleted along
        with any ``XmppIdsMulti`` rows for the given ``xmpp_ids``, and the
        mapping is recreated. Empty strings in ``xmpp_ids`` are skipped.

        :param fail: internal guard for the retry after deletion. If a mapping
            still exists when this is True, ``RuntimeError`` is raised.
            Previously this was a bare ``raise``, which also raised
            RuntimeError but with no meaningful message.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # Stops the retry below from recursing forever.
                    raise RuntimeError(
                        f"Could not reset multi for legacy message {legacy_msg_id}"
                    )
                log.debug("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                if xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id.in_(xmpp_ids))
                    )
                session.commit()
                self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
                return

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy message ID that ``xmpp_id`` belongs to, or None."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id

745 

746 

class AttachmentStore(EngineMixin):
    """
    Cache of attachments: legacy file ID -> URL, plus per-URL ``sims``/``sfs``
    strings (presumably serialized SIMS/SFS payloads).
    """

    def _column_by_url(self, column, url: str) -> Optional[str]:
        """Return ``column`` of the first attachment whose URL equals ``url``."""
        with self.session() as session:
            return session.execute(select(column).where(Attachment.url == url)).scalar()

    def _update_by_url(self, url: str, **values) -> None:
        """Set ``values`` on every attachment whose URL equals ``url``."""
        with self.session() as session:
            session.execute(
                update(Attachment).where(Attachment.url == url).values(**values)
            )
            session.commit()

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL stored for ``legacy_file_id`` (for any user), or None."""
        stmt = select(Attachment.url).where(Attachment.legacy_file_id == legacy_file_id)
        with self.session() as session:
            return session.execute(stmt).scalar()

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Create or update the URL of ``legacy_file_id`` for user ``user_pk``."""
        with self.session() as session:
            attachment = session.execute(
                select(Attachment).where(
                    Attachment.legacy_file_id == legacy_file_id,
                    Attachment.user_account_id == user_pk,
                )
            ).scalar()
            if attachment is not None:
                attachment.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id,
                        url=url,
                        user_account_id=user_pk,
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Return the cached ``sims`` value for ``url``, or None."""
        return self._column_by_url(Attachment.sims, url)

    def set_sims(self, url: str, sims: str) -> None:
        """Store ``sims`` on the attachment(s) with this URL."""
        self._update_by_url(url, sims=sims)

    def get_sfs(self, url: str) -> Optional[str]:
        """Return the cached ``sfs`` value for ``url``, or None."""
        return self._column_by_url(Attachment.sfs, url)

    def set_sfs(self, url: str, sfs: str) -> None:
        """Store ``sfs`` on the attachment(s) with this URL."""
        self._update_by_url(url, sfs=sfs)

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row for ``legacy_file_id``."""
        stmt = delete(Attachment).where(Attachment.legacy_file_id == legacy_file_id)
        with self.session() as session:
            session.execute(stmt)
            session.commit()

804 

805 

806class RoomStore(UpdatedMixin): 

807 model = Room 

808 

809 def __init__(self, *a, **kw): 

810 super().__init__(*a, **kw) 

811 with self.session() as session: 

812 session.execute( 

813 update(Room).values( 

814 subject_setter=None, 

815 user_resources=None, 

816 history_filled=False, 

817 participants_filled=False, 

818 ) 

819 ) 

820 session.commit() 

821 

822 def set_avatar( 

823 self, room_pk: int, avatar_pk: int | None, avatar_legacy_id: str | None 

824 ) -> None: 

825 with self.session() as session: 

826 session.execute( 

827 update(Room) 

828 .where(Room.id == room_pk) 

829 .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id) 

830 ) 

831 session.commit() 

832 

833 def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]: 

834 with self.session() as session: 

835 room = session.execute(select(Room).where(Room.id == room_pk)).scalar() 

836 if room is None or room.avatar is None: 

837 return None 

838 return room.avatar_legacy_id 

839 

840 def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]: 

841 if jid.resource: 

842 raise TypeError 

843 with self.session() as session: 

844 return session.execute( 

845 select(Room) 

846 .where(Room.user_account_id == user_pk) 

847 .where(Room.jid == jid) 

848 ).scalar() 

849 

850 def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]: 

851 with self.session() as session: 

852 return session.execute( 

853 select(Room) 

854 .where(Room.user_account_id == user_pk) 

855 .where(Room.legacy_id == legacy_id) 

856 ).scalar() 

857 

858 def update_subject_setter(self, room_pk: int, subject_setter: str | None): 

859 with self.session() as session: 

860 session.execute( 

861 update(Room) 

862 .where(Room.id == room_pk) 

863 .values(subject_setter=subject_setter) 

864 ) 

865 session.commit() 

866 

867 def update(self, muc: "LegacyMUC") -> int: 

868 with self.session() as session: 

869 if muc.pk is None: 

870 row = Room( 

871 jid=muc.jid, 

872 legacy_id=str(muc.legacy_id), 

873 user_account_id=muc.user_pk, 

874 ) 

875 else: 

876 row = session.query(Room).filter(Room.id == muc.pk).one() 

877 

878 row.updated = True 

879 row.extra_attributes = muc.serialize_extra_attributes() 

880 row.name = muc.name 

881 row.description = muc.description 

882 row.user_resources = ( 

883 None 

884 if not muc._user_resources 

885 else json.dumps(list(muc._user_resources)) 

886 ) 

887 row.muc_type = muc.type 

888 row.subject = muc.subject 

889 row.subject_date = muc.subject_date 

890 row.subject_setter = muc.subject_setter 

891 row.participants_filled = muc._participants_filled 

892 row.n_participants = muc._n_participants 

893 row.user_nick = muc.user_nick 

894 session.add(row) 

895 session.commit() 

896 return row.id 

897 

898 def update_subject_date( 

899 self, room_pk: int, subject_date: Optional[datetime] 

900 ) -> None: 

901 with self.session() as session: 

902 session.execute( 

903 update(Room).where(Room.id == room_pk).values(subject_date=subject_date) 

904 ) 

905 session.commit() 

906 

907 def update_subject(self, room_pk: int, subject: Optional[str]) -> None: 

908 with self.session() as session: 

909 session.execute( 

910 update(Room).where(Room.id == room_pk).values(subject=subject) 

911 ) 

912 session.commit() 

913 

914 def update_description(self, room_pk: int, desc: Optional[str]) -> None: 

915 with self.session() as session: 

916 session.execute( 

917 update(Room).where(Room.id == room_pk).values(description=desc) 

918 ) 

919 session.commit() 

920 

921 def update_name(self, room_pk: int, name: Optional[str]) -> None: 

922 with self.session() as session: 

923 session.execute(update(Room).where(Room.id == room_pk).values(name=name)) 

924 session.commit() 

925 

926 def update_n_participants(self, room_pk: int, n: Optional[int]) -> None: 

927 with self.session() as session: 

928 session.execute( 

929 update(Room).where(Room.id == room_pk).values(n_participants=n) 

930 ) 

931 session.commit() 

932 

933 def update_user_nick(self, room_pk, nick: str) -> None: 

934 with self.session() as session: 

935 session.execute( 

936 update(Room).where(Room.id == room_pk).values(user_nick=nick) 

937 ) 

938 session.commit() 

939 

940 def delete(self, room_pk: int) -> None: 

941 with self.session() as session: 

942 session.execute(delete(Room).where(Room.id == room_pk)) 

943 session.execute(delete(Participant).where(Participant.room_id == room_pk)) 

944 session.commit() 

945 

946 def set_resource(self, room_pk: int, resources: set[str]) -> None: 

947 with self.session() as session: 

948 session.execute( 

949 update(Room) 

950 .where(Room.id == room_pk) 

951 .values( 

952 user_resources=( 

953 None if not resources else json.dumps(list(resources)) 

954 ) 

955 ) 

956 ) 

957 session.commit() 

958 

959 def nickname_is_available(self, room_pk: int, nickname: str) -> bool: 

960 with self.session() as session: 

961 return ( 

962 session.execute( 

963 select(Participant) 

964 .where(Participant.room_id == room_pk) 

965 .where(Participant.nickname == nickname) 

966 ).scalar() 

967 is None 

968 ) 

969 

970 def set_participants_filled(self, room_pk: int, val=True) -> None: 

971 with self.session() as session: 

972 session.execute( 

973 update(Room).where(Room.id == room_pk).values(participants_filled=val) 

974 ) 

975 session.commit() 

976 

def set_history_filled(self, room_pk: int, val=True) -> None:
    """
    Set the ``history_filled`` flag of room *room_pk* to *val* (True by default).

    Mirrors ``set_participants_filled``: the flag is written from *val* rather
    than being hard-coded to True, so callers can also reset it.
    """
    with self.session() as session:
        session.execute(
            update(Room).where(Room.id == room_pk).values(history_filled=val)
        )
        session.commit()

983 

def get_all(self, user_pk: int) -> Iterator[Room]:
    """Lazily yield every room owned by the gateway user with primary key *user_pk*."""
    query = select(Room).where(Room.user_account_id == user_pk)
    with self.session() as session:
        for room in session.execute(query).scalars():
            yield room

989 

def get_all_jid_and_names(self, user_pk: int) -> Iterator[Room]:
    """
    Yield the rooms of user *user_pk*, ordered by name.

    Only the ``jid`` and ``name`` columns are loaded, and all rows are fetched
    (``.all()``) before the first one is yielded.
    """
    query = (
        select(Room)
        .filter(Room.user_account_id == user_pk)
        .options(load_only(Room.jid, Room.name))
        .order_by(Room.name)
    )
    with self.session() as session:
        rooms = session.scalars(query).all()
        yield from rooms

998 

999 

class ParticipantStore(EngineMixin):
    """Database access for MUC participants and the hats attached to them."""

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Every participant, hat and participant/hat link is wiped whenever a
        # ParticipantStore is constructed.
        with self.session() as session:
            session.execute(delete(participant_hats))
            session.execute(delete(Hat))
            session.execute(delete(Participant))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """
        Return the primary key of the participant *nickname* in room *room_pk*,
        creating the row first if it does not exist yet.
        """
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the participant of room *room_pk* using *nickname*, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the participant of room *room_pk* with the given *resource*, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the participant of room *room_pk* linked to contact *contact_pk*, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """
        Yield the participants of room *room_pk*.

        When *user_included* is False, rows flagged ``is_user`` are skipped.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant row, across all rooms, linked to contact *contact_pk*."""
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Write the scalar state of *participant* back to its row (keyed by ``participant.pk``)."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    # Hats are not written here; they are persisted by set_hats().
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the Hat row matching (*uri*, *title*), creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Mark participant *participant_pk* as having had its presence sent."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Store the MUC *affiliation* of participant *participant_pk*."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Store the MUC *role* of participant *participant_pk*."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """
        Replace the hats of participant *participant_pk* with *hats*.

        Duplicate entries in *hats* are attached only once.

        :raises ValueError: if no participant with that primary key exists.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError(f"No participant with pk {participant_pk}")
            part.hats.clear()
            for h in hats:
                # add_hat runs inside this same session (EngineMixin.session
                # reuses the open one), so the returned Hat is attached here.
                hat = self.add_hat(*h)
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete participant *participant_pk* and commit the change."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without an explicit commit the delete is rolled back when the
            # session closes.
            session.commit()

    def get_count(self, room_pk: int) -> int:
        """Return the number of participants in room *room_pk*."""
        with self.session() as session:
            # A plain WHERE lets the database filter (and use indexes) instead
            # of evaluating an aggregate FILTER clause over every participant.
            return (
                session.scalar(
                    select(count(Participant.id)).where(
                        Participant.room_id == room_pk
                    )
                )
                or 0
            )

1144 

1145 

class BobStore(EngineMixin):
    """
    Storage for Bits of Binary (BoB) payloads such as stickers.

    Payloads are written to files under ``config.HOME_DIR / "slidge_stickers"``.
    Each ``Bob`` row keeps the file name, the content type and the
    sha_1 / sha_256 / sha_512 hex digests of the payload.
    """

    # CID hash-algorithm names -> Bob column holding that digest.
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob digest column -> hashlib constructor computing it.
    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        self.root_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        # "algo+digest@bob.xmpp.org" -> ["algo", "digest"]
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        """
        Build the SQL condition matching the Bob row for *cid*.

        Returns None (after logging) when the hash algorithm is unknown.
        Raises ValueError when *cid* does not split into exactly two parts.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algo: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, cid: str) -> Bob | None:
        """Return the stored Bob row for *cid*, or None if unknown or malformed."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None
        # Compare with `is None`: a SQL expression must not be used in a
        # boolean context.
        if condition is None:
            return None
        with self.session() as session:
            return session.query(Bob).filter(condition).scalar()

    def get_sticker(self, cid: str) -> Sticker | None:
        """Return a Sticker (path, content type, digests by column name) for *cid*, or None."""
        bob = self.get(cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(self, _jid, _node, _ifrom, cid: str) -> BitsOfBinary | None:
        """
        Build a BitsOfBinary stanza element for *cid* from the stored file, or None.

        The unused leading arguments presumably match the slixmpp XEP-0231
        handler signature — TODO confirm.
        """
        stored = self.get(cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, _jid, _node, _ifrom, cid: str) -> None:
        """Delete the Bob row for *cid* and then its payload file."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if condition is None:
            return None
        with self.session() as orm:
            file_name = orm.scalar(
                delete(Bob).where(condition).returning(Bob.file_name)
            )
            if file_name is None:
                log.warning("No BoB with CID: %s", cid)
                return None
            orm.commit()
        # Remove the file only once the row is gone, and tolerate a file that
        # has already disappeared instead of aborting the deletion.
        (self.root_dir / file_name).unlink(missing_ok=True)

    def set_bob(self, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        """
        Store the payload of *bob* on disk and record it in the database.

        Payloads whose CID is already known are ignored.

        :raises ValueError: if the payload digest does not match the CID.
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Cannot set Bob with CID: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set BoB with unknown hash algo: %s", alg_name)
            return None
        with self.session() as orm:
            existing = self.get(cid)
            if existing is not None:
                log.debug("Bob already known")
                return
            bytes_ = bob["data"]
            hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
            # Verify the digest before writing anything so that a bad CID does
            # not leave an orphan file in root_dir.
            if hashes[attr] != digest:
                raise ValueError(
                    "The given CID does not correspond to the result of our hash"
                )
            path = self.root_dir / uuid.uuid4().hex
            if bob["type"]:
                path = path.with_suffix(guess_extension(bob["type"]) or "")
            path.write_bytes(bytes_)
            row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
            orm.add(row)
            orm.commit()

1254 

1255 

# Module logger, used at runtime by the store classes (e.g. BobStore warnings).
log: logging.Logger = logging.getLogger(__name__)

# Session shared by nested EngineMixin.session() calls: set while the outermost
# `with self.session()` block is open, reset to None when it exits.
_session: Optional[Session] = None