Coverage for slidge/db/store.py: 88%

315 statements  

coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

from __future__ import annotations

import hashlib
import logging
import shutil
import uuid
from datetime import datetime, timedelta, timezone
from mimetypes import guess_extension
from typing import Collection, Iterator, Optional, Type

import sqlalchemy as sa
from slixmpp.exceptions import XMPPError
from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
from sqlalchemy import Engine, delete, event, select, update
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, attributes, sessionmaker

from ..core import config
from ..util.archive_msg import HistoryMessage
from ..util.types import MamMetadata, Sticker
from .meta import Base
from .models import (
    ArchivedMessage,
    ArchivedMessageSource,
    Avatar,
    Bob,
    Contact,
    ContactSent,
    DirectMessages,
    DirectThreads,
    GatewayUser,
    GroupMessages,
    GroupMessagesOrigin,
    GroupThreads,
    Participant,
    Room,
)


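# Base class for stores backed by a model carrying an ``updated`` flag: the flag
# is cleared on every row at startup, and get_by_pk() is a plain primary-key lookup.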
class UpdatedMixin:
    model: Type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        session.execute(update(self.model).values(updated=False))

    def get_by_pk(self, session: Session, pk: int) -> Type[Base]:
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        return session.scalar(stmt)


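# Facade wiring the per-table stores to a single SQLAlchemy engine. Stores that
# perform startup cleanup (contacts, MAM, rooms, participants) share one session,
# committed as soon as they are constructed.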
class SlidgeStore:
    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker(engine)

        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()
        with self.session() as session:
            self.contacts = ContactStore(session)
            self.mam = MAMStore(session, self.session)
            self.rooms = RoomStore(session)
            self.participants = ParticipantStore(session)
            session.commit()


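# Persists GatewayUser rows. flag_modified() is called because in-place changes to
# the mutable columns (legacy_module_data, preferences) would presumably not be
# detected otherwise; see the SQLAlchemy discussion linked in update().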
class UserStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                attributes.flag_modified(user, "legacy_module_data")
                attributes.flag_modified(user, "preferences")
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()


class AvatarStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker


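# Union of the models mapping a legacy identifier to one or more XMPP ids,
# consumed by IdMapStore below.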
LegacyToXmppType = (
    Type[DirectMessages]
    | Type[DirectThreads]
    | Type[GroupMessages]
    | Type[GroupThreads]
    | Type[GroupMessagesOrigin]
)


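# Bidirectional mapping between legacy message/thread ids and the XMPP ids used on
# the wire, for both direct chats and groups. _set() drops any rows already stored
# for a legacy id before inserting the new ones.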
class IdMapStore:
    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        kwargs = dict(foreign_key=foreign_key, legacy_id=legacy_id)
        ids = session.scalars(
            select(type_.id).filter(
                type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            msg = type_(xmpp_id=xmpp_id, **kwargs)
            session.add(msg)

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=legacy_id
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> Optional[str]:
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> Optional[str]:
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )


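# Cached presence and caps verification strings are cleared at startup.
# add_to_sent() records message ids sent by a contact; pop_sent_up_to() drains and
# returns them up to (and including) a given id.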
class ContactStore(UpdatedMixin):
    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(update(Contact).values(cached_presence=False))
        session.execute(update(Contact).values(caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        if (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        ) is not None:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
        session.add(new)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        result = []
        to_del = []
        for row in session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars():
            to_del.append(row.id)
            result.append(row.msg_id)
            if row.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
        return result


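# Message Archive Management (XEP-0313) storage. At startup every archived message
# is re-marked as BACKFILL, so only messages archived during the current run carry
# the LIVE source.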
class MAMStore:
    def __init__(self, session: Session, session_maker) -> None:
        self.session = session_maker
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        session.execute(
            delete(ArchivedMessage).where(
                ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
            )
        )

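    # add_message() behaves like an upsert: an existing row is matched first by
    # stanza id, then by legacy id, and updated in place; otherwise a new row is
    # inserted.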
    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: Optional[str],
    ) -> None:
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

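    # get_messages() applies the MAM query filters (date range, before/after a
    # stanza id, explicit id set, sender, flipped ordering, last-page truncation)
    # and yields HistoryMessage objects with UTC timestamps.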
    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> Optional[ArchivedMessage]:
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> Optional[ArchivedMessage]:
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )


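# Volatile room state (subject setter, user resources, history/participants filled
# flags) is reset at startup.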
class RoomStore(UpdatedMixin):
    model = Room

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))


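# Participants are not kept across restarts: the table is emptied on init.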
class ParticipantStore:
    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        query = select(Participant).where(Participant.room_id == room_pk)
        if not user_included:
            query = query.where(~Participant.is_user)
        yield from session.scalars(query).unique()


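# Bits of Binary (XEP-0231) storage: payloads live on disk under bob_store/, while
# Bob rows keep the file name, content type and sha-1/256/512 digests used to
# resolve cid values.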
class BobStore:
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

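    # cid values follow XEP-0231's "<algo>+<digest>@bob.xmpp.org" form; _ATTR_MAP
    # translates wire algorithm names to Bob hash columns, and _ALG_MAP to the
    # hashlib constructors used to verify digests in set_bob().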
    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, session: Session, cid: str) -> Bob | None:
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid, _node, _ifrom, cid: str
    ) -> BitsOfBinary | None:
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        (self.root_dir / file_name).unlink()

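    # set_bob() writes the payload to disk, computes every supported digest, and
    # removes the file again if the digest embedded in the provided cid does not
    # match.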
    def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, bob["cid"])
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            path.unlink(missing_ok=True)
            raise ValueError("Provided CID does not match calculated hash")
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)


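# After contacts or rooms are deleted, avatars no longer referenced by any contact
# or room are garbage-collected in the same flush.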
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session, flush_context):
    if not session.deleted:
        return

    potentially_orphaned = set()
    for obj in session.deleted:
        if isinstance(obj, (Contact, Room)) and obj.avatar_id:
            potentially_orphaned.add(obj.avatar_id)
    if not potentially_orphaned:
        return

    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(potentially_orphaned),
                sa.not_(sa.exists().where(Contact.avatar_id == Avatar.id)),
                sa.not_(sa.exists().where(Room.avatar_id == Avatar.id)),
            )
        )
    )
    deleted_count = result.rowcount
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)


log = logging.getLogger(__name__)