Coverage for slidge / db / store.py: 90%

388 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 04:38 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import logging 

5import shutil 

6import uuid 

7from collections.abc import Callable, Collection, Iterable, Iterator 

8from datetime import UTC, datetime, timedelta 

9from mimetypes import guess_extension 

10from typing import Any, ClassVar 

11 

12import sqlalchemy as sa 

13from slixmpp.exceptions import XMPPError 

14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary 

15from sqlalchemy import ColumnElement, Engine, delete, event, select, update 

16from sqlalchemy.exc import InvalidRequestError 

17from sqlalchemy.orm import Session, attributes, joinedload, load_only, sessionmaker 

18 

19from ..core import config 

20from ..util.archive_msg import HistoryMessage 

21from ..util.types import MamMetadata, Sticker 

22from .meta import Base 

23from .models import ( 

24 ArchivedMessage, 

25 ArchivedMessageSource, 

26 Avatar, 

27 Bob, 

28 Contact, 

29 ContactSent, 

30 DirectMessages, 

31 DirectThreads, 

32 GatewayUser, 

33 GroupMessages, 

34 GroupMessagesOrigin, 

35 GroupThreads, 

36 Participant, 

37 Room, 

38 Space, 

39) 

40 

41 

class UpdatedMixin:
    """Shared helpers for stores whose model has an ``updated`` column.

    Concrete stores set ``model`` to the mapped class they manage.
    """

    # Mapped class handled by the store; overridden by subclasses.
    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        """Clear the ``updated`` flag of every row when the store is created."""
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> type[Base]:
        """Return the ``model`` row whose ``id`` equals ``pk``.

        The lookup goes through ``Session.scalar``.
        NOTE(review): the return annotation says ``type[Base]`` although a row
        (or ``None``) is what the query produces -- confirm before relying on it.
        """
        model = self.model
        by_primary_key = select(model).where(model.id == pk)  # type:ignore[attr-defined]
        return session.scalar(by_primary_key)  # type:ignore[no-any-return]

    def reset_updated(self, session: Session) -> None:
        """Set ``updated`` to ``False`` on all rows of ``model``."""
        clear_flags = update(self.model).values(updated=False)
        session.execute(clear_flags)

54 

55 

class SlidgeStore:
    """Facade that bundles every per-table store behind a single engine."""

    def __init__(self, engine: Engine) -> None:
        """Create the session factory and instantiate each sub-store.

        The stores that receive a live session are all built inside one
        session, which is committed once the last of them has been created.
        """
        self._engine = engine
        session_factory = sessionmaker[Any](engine)
        self.session = session_factory

        # These stores are built without an open session.
        self.users = UserStore(session_factory)
        self.avatars = AvatarStore(session_factory)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        with session_factory() as startup:
            self.contacts = ContactStore(startup)
            self.mam = MAMStore(startup, session_factory)
            self.rooms = RoomStore(startup)
            self.participants = ParticipantStore(startup)
            self.spaces = SpaceStore(startup)
            startup.commit()

72 

73 

class UserStore:
    """Persistence helper for ``GatewayUser`` rows."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        """Remember the session factory used by :meth:`update`."""
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Persist ``user`` in a fresh session and commit.

        The session is opened with ``expire_on_commit=False`` so the instance
        stays usable after the commit.
        """
        with self.session(expire_on_commit=False) as session:
            # Force these attributes to be treated as modified; see
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                for column in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, column)
            except InvalidRequestError:
                # Deliberately ignored: flagging is best effort.
                pass
            session.add(user)
            session.commit()

88 

89 

class AvatarStore:
    """Holder of the session factory for avatar-related database access."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        """Keep ``session_maker`` available as ``self.session``."""
        self.session = session_maker

93 

94 

# Any of the mapped classes that IdMapStore can read from or write to.
LegacyToXmppType = (
    type[DirectMessages]
    | type[DirectThreads]
    | type[GroupMessages]
    | type[GroupThreads]
    | type[GroupMessagesOrigin]
)

102 

103 

class IdMapStore:
    """Mapping between legacy-network IDs and XMPP IDs of messages and threads.

    Each method works on one of the ``LegacyToXmppType`` tables, selected from
    the ``group``/``origin`` arguments, and filters rows by ``foreign_key``.
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace the XMPP IDs recorded for ``legacy_id`` in table ``type_``.

        Rows already present for the ``foreign_key``/``legacy_id`` pair are
        deleted, then one row per entry of ``xmpp_ids`` is added to the session.
        """
        stale_pks = list(
            session.scalars(
                select(type_.id)
                .where(type_.foreign_key == foreign_key)
                .where(type_.legacy_id == legacy_id)
            )
        )
        if stale_pks:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(stale_pks)))
        session.add_all(
            type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            for xmpp_id in xmpp_ids
        )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Record ``xmpp_id`` as the thread ID matching ``legacy_id``."""
        table = GroupThreads if group else DirectThreads
        self._set(session, foreign_key, legacy_id, [xmpp_id], table)

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Record ``xmpp_ids`` as the message IDs matching ``legacy_id``."""
        table = GroupMessages if group else DirectMessages
        self._set(session, foreign_key, legacy_id, xmpp_ids, table)

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Record ``xmpp_id`` in the group-message origin table."""
        self._set(session, foreign_key, legacy_id, [xmpp_id], GroupMessagesOrigin)

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the XMPP IDs stored in the origin table for ``legacy_id``."""
        return self._get(session, foreign_key, legacy_id, GroupMessagesOrigin)

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return every XMPP ID stored in ``type_`` for ``legacy_id``.

        ``legacy_id`` is converted with ``str()`` before being compared.
        """
        matching = select(type_.xmpp_id).filter_by(
            foreign_key=foreign_key, legacy_id=str(legacy_id)
        )
        return list(session.scalars(matching))

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs stored for ``legacy_id``."""
        table = GroupMessages if group else DirectMessages
        return self._get(session, foreign_key, legacy_id, table)

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID stored in ``type_`` for ``xmpp_id``, if any."""
        matching = select(type_.legacy_id).filter_by(
            foreign_key=foreign_key, xmpp_id=xmpp_id
        )
        return session.scalar(matching)

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID matching ``xmpp_id``, if any.

        The origin table is consulted only when both ``origin`` and ``group``
        are true; otherwise the group or direct message table is used.
        """
        table: LegacyToXmppType
        if group and origin:
            table = GroupMessagesOrigin
        elif group:
            table = GroupMessages
        else:
            table = DirectMessages
        return self._get_legacy(session, foreign_key, xmpp_id, table)

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID matching ``xmpp_id``, if any."""
        table = GroupThreads if group else DirectThreads
        return self._get_legacy(session, foreign_key, xmpp_id, table)

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Tell whether a message row exists for ``legacy_id``."""
        table = GroupMessages if group else DirectMessages
        found_pk = session.scalar(
            select(table.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
        )
        return found_pk is not None

254 

255 

class ContactStore(UpdatedMixin):
    """Store for ``Contact`` rows and the per-contact list of sent message IDs."""

    model = Contact

    def __init__(self, session: Session) -> None:
        """Reset the ``updated`` flag, cached presence and caps of every contact."""
        super().__init__(session)
        # One statement instead of one full-table UPDATE per column.
        session.execute(update(Contact).values(cached_presence=False, caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Record that contact ``contact_pk`` sent message ``msg_id``.

        If the pair is already recorded, a warning is logged and nothing is
        added.
        """
        already_recorded = session.scalar(
            select(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .limit(1)
        )
        if already_recorded is not None:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return the recorded message IDs up to and including ``msg_id``.

        Rows are visited in ``ContactSent.id`` order. When ``msg_id`` is never
        encountered, every recorded row of the contact is popped.

        :return: The popped message IDs, oldest row first.
        """
        result: list[str] = []
        to_del: list[int] = []
        pending = session.execute(
            select(ContactSent.id, ContactSent.msg_id)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        )
        for pk, sent_id in pending:
            to_del.append(pk)
            result.append(sent_id)
            if sent_id == msg_id:
                break
        # Nothing recorded for this contact: skip the pointless DELETE.
        if to_del:
            session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
        return result

292 

293 

class MAMStore:
    """Queries on the ``ArchivedMessage`` table (per-room message archive)."""

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        """Keep the session factory and reset the source of all archived messages."""
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Set ``source`` to ``BACKFILL`` on every archived message."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than ``days`` days old."""
        # get_messages() treats stored timestamps as UTC, so the cutoff must be
        # computed from a UTC clock; a naive local "now" shifts it by the host's
        # UTC offset.
        # NOTE(review): assumes the column accepts an aware datetime the same way
        # it does when messages are written -- confirm against the model.
        cutoff = datetime.now(UTC) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Insert ``message`` into the archive of room ``room_pk``, or update it.

        An existing row is looked up first by stanza ID, then (when
        ``legacy_msg_id`` is given) by legacy ID. If one is found its content is
        overwritten; otherwise a new row is added to the session.

        :param archive_only: Store the message with source ``BACKFILL`` instead
            of ``LIVE``.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        in_room = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        existing = session.execute(
            in_room.where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                in_room.where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        session.add(
            ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
        )

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of ``stanza_id`` in the room.

        :raises XMPPError: ``item-not-found`` when the room has no such message.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp).where(
                ArchivedMessage.stanza_id == stanza_id,
                ArchivedMessage.room_id == room_pk,
            )
        ).scalar_one_or_none()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield the archived messages of a room matching the given filters.

        This is a generator: queries run, and errors are raised, on first
        iteration.

        :param start_date: Keep messages with a timestamp at or after this date.
        :param end_date: Keep messages with a timestamp at or before this date.
        :param before_id: Keep messages strictly older than this stanza ID.
        :param after_id: Keep messages strictly newer than this stanza ID.
        :param ids: If non-empty, keep only messages with these stanza IDs.
        :param last_page_n: Keep only this many messages, taken from the end of
            the chronological order; ``0`` yields nothing.
        :param sender: Keep messages whose ``author_jid`` equals this value.
        :param flip: Return messages newest first instead of oldest first.
        :raises XMPPError: ``item-not-found`` if ``before_id``/``after_id`` is
            unknown in the room, or if the number of matches differs from the
            number of requested ``ids``.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            before = MAMStore._timestamp_of(session, room_pk, before_id)
            q = q.where(ArchivedMessage.timestamp < before)
        if after_id is not None:
            after = MAMStore._timestamp_of(session, room_pk, after_id)
            q = q.where(ArchivedMessage.timestamp > after)
        if ids:
            q = q.where(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        timestamp = ArchivedMessage.timestamp
        q = q.order_by(timestamp.desc() if flip else timestamp.asc())

        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                # msgs[-0:] is the whole list, so zero needs its own branch to
                # return an empty page like the flipped case does.
                msgs = msgs[-last_page_n:] if last_page_n else []
        for archived in msgs:
            yield HistoryMessage(
                stanza=str(archived.stanza),
                # Stored timestamps are interpreted as UTC.
                when=archived.timestamp.replace(tzinfo=UTC),
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room, if any.

        :param with_legacy_id: Ignore messages whose ``legacy_id`` is NULL.
        """
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.where(ArchivedMessage.legacy_id.is_not(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, if any.

        :param source: When given, only consider messages with this source.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest, then the newest message of the room.

        The list is empty when the room archive is empty.
        """
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message of the room that has a legacy ID, if any.

        :param source: When given, only consider messages with this source.
        """
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.is_not(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID that is newer than ``after_id``.

        ``after_id`` is a legacy ID. ``None`` is returned when it is unknown in
        the room or when no newer message with the given ``source`` exists.
        """
        after_timestamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id == after_id)
        ).scalar_one_or_none()
        if after_timestamp is None:
            # Comparing against NULL can never match; skip the second query.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return the first message of the room with this legacy ID, if any."""
        return session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id == legacy_id)
            .limit(1)
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark unread messages up to ``stanza_id`` as displayed and return their IDs.

        Only messages with a legacy ID that are not yet displayed by the user
        are considered. When ``stanza_id`` is unknown in the room, all of them
        are marked.

        :return: Stanza IDs of the messages that were marked, oldest first.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
        else:
            q = q.where(ArchivedMessage.timestamp <= ref.timestamp)

        pks: list[int] = []
        stanza_ids: list[str] = []
        # Distinct loop names so the ``stanza_id`` parameter is not shadowed.
        for pk, unread_stanza_id in session.execute(q):
            pks.append(pk)
            stanza_ids.append(unread_stanza_id)

        if pks:
            session.execute(
                update(ArchivedMessage)
                .where(ArchivedMessage.id.in_(pks))
                .values(displayed_by_user=True)
            )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Tell whether any matching message in the room is flagged as displayed.

        Messages are matched on ``legacy_msg_id`` within the room whose JID is
        ``room_jid``.
        """
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )

553 

554 

class RoomStore(UpdatedMixin):
    """Store for ``Room`` rows."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear the ``updated`` flag and other per-run state of every room.

        On top of the base-class reset, the subject setter and user resources
        are set to NULL and the two "filled" flags are set to ``False``.
        """
        super().reset_updated(session)
        per_run_state = {
            "subject_setter": None,
            "user_resources": None,
            "history_filled": False,
            "participants_filled": False,
        }
        session.execute(update(Room).values(**per_run_state))

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room that belongs to user ``user_pk``."""
        owned_by_user = select(Room).where(Room.user_account_id == user_pk)
        yield from session.scalars(owned_by_user)

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the single room of ``user_pk`` with this ``legacy_id``.

        Uses ``scalar_one()``, which raises unless exactly one row matches.
        """
        lookup = (
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        )
        return session.execute(lookup).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Tell whether no participant of the room uses ``nickname``."""
        holder = session.execute(
            select(Participant.id).filter_by(room_id=room_pk, nickname=nickname)
        ).one_or_none()
        return holder is None

588 

589 

class ParticipantStore:
    """Store for ``Participant`` rows."""

    def __init__(self, session: Session) -> None:
        """Delete every participant row when the store is created."""
        wipe_all = delete(Participant)
        session.execute(wipe_all)

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room ``room_pk``.

        :param user_included: When false, rows flagged ``is_user`` are left out.
        """
        in_room = select(Participant).where(Participant.room_id == room_pk)
        if user_included:
            query = in_room
        else:
            query = in_room.where(~Participant.is_user)
        yield from session.scalars(query).unique()

    @staticmethod
    def delete(session: Session, pk: int) -> None:
        """Delete the participant whose primary key is ``pk``."""
        # ``delete`` below is the module-level SQLAlchemy construct, not this method.
        remove_one = delete(Participant).where(Participant.id == pk)
        session.execute(remove_one)

606 

607 

class BobStore:
    """Store for Bits of Binary payloads: metadata in the DB, bytes on disk.

    Files live in ``config.HOME_DIR / "bob_store"``; each ``Bob`` row keeps the
    file name, the content type and the digests listed in ``_ALG_MAP``.
    """

    # Hash names accepted in a CID, mapped to the ``Bob`` column holding the digest.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # ``Bob`` digest column -> hashlib constructor used to compute it.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        """Prepare the storage directory, renaming ``slidge_stickers`` if present."""
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split a CID into its ``+``-separated parts, dropping the BoB domain."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build the SQL condition matching the ``Bob`` row designated by ``cid``.

        :raises ValueError: If ``cid`` does not split into exactly two parts or
            names an algorithm absent from ``_ATTR_MAP``.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError(f"Unknown hash algorithm: {alg_name}")
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the ``Bob`` row designated by ``cid``, or ``None``.

        An unparsable CID is logged and reported as ``None``.
        """
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()  # type:ignore[no-any-return]
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored payload designated by ``cid`` as a ``Sticker``."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return self.__sticker_from_bob(bob)

    def __sticker_from_bob(self, bob: Bob) -> Sticker:
        """Build a ``Sticker`` (path, content type, digests) from a ``Bob`` row."""
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Return a ``BitsOfBinary`` stanza filled from storage, or ``None``.

        ``_jid``, ``_node`` and ``_ifrom`` are accepted but not used.
        """
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the row designated by ``cid`` and its file on disk.

        Unknown or unparsable CIDs are logged and otherwise ignored.
        ``_jid``, ``_node`` and ``_ifrom`` are accepted but not used.
        """
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        # The row is already gone; a file that vanished earlier must not turn
        # the cleanup into an error.
        (self.root_dir / file_name).unlink(missing_ok=True)

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> Sticker | None:
        """Store the payload of a ``BitsOfBinary`` stanza; see :meth:`set_sticker`.

        ``_jid``, ``_node`` and ``_ifrom`` are accepted but not used.
        """
        return self.set_sticker(session, bob["cid"], bob["data"], bob["type"])

    def set_sticker(
        self,
        session: Session,
        cid: str,
        bytes_: bytes,
        content_type: str | None,
    ) -> Sticker | None:
        """Store ``bytes_`` under ``cid`` and return the resulting ``Sticker``.

        :param content_type: MIME type of the payload. When missing or empty it
            is detected with ``magic`` if that module can be imported, else
            ``application/octet-stream`` is used.
        :return: The new sticker, or ``None`` if the CID is invalid, uses an
            unknown algorithm, or already designates a stored payload.
        :raises ValueError: If the digest in ``cid`` does not match ``bytes_``.
        """
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return None
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return None
        if self.get(session, cid) is not None:
            log.debug("Bob already exists")
            return None

        # Verify the digest before touching the disk, so a mismatching payload
        # never gets written at all.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")

        # An empty string is treated like a missing type.
        # NOTE(review): presumably what a stanza without a type attribute yields
        # via set_bob() -- confirm against slixmpp.
        if not content_type:
            try:
                import magic
            except ImportError:
                content_type = "application/octet-stream"
            else:
                content_type = magic.from_buffer(bytes_, mime=True)

        path = (self.root_dir / uuid.uuid4().hex).with_suffix(
            guess_extension(content_type) or ""
        )
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=content_type, **hashes)
        session.add(row)
        return self.__sticker_from_bob(row)

739 

740 

class SpaceStore(UpdatedMixin):
    """Store for ``Space`` rows."""

    model = Space

    def __init__(self, session: Session) -> None:
        """Delete every space row when the store is created.

        The base-class constructor is not invoked here.
        """
        wipe_all = delete(Space)
        session.execute(wipe_all)

    @staticmethod
    def add_or_get(session: Session, user_pk: int, legacy_id: str) -> Space:
        """Return the space of ``user_pk`` with this ``legacy_id``, creating it if needed.

        A newly created space is added to the session, which is then committed.
        """
        lookup = (
            select(Space)
            .where(Space.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
        )
        found = session.execute(lookup).scalar_one_or_none()
        if found is not None:
            return found
        created = Space(user_account_id=user_pk, legacy_id=legacy_id)
        session.add(created)
        session.commit()
        return created

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterable[Space]:
        """Return the spaces that belong to user ``user_pk``."""
        owned_by_user = select(Space).where(Space.user_account_id == user_pk)
        return session.execute(owned_by_user).scalars()

    @staticmethod
    def get_by_legacy_id(
        session: Session, user_pk: int, legacy_id: str, full: bool = False
    ) -> Space | None:
        """Return the space of ``user_pk`` with this ``legacy_id``, if any.

        :param full: Also load ``Space.owners`` and ``Space.creator`` through
            joined loads.
        """
        stmt = (
            select(Space)
            .where(Space.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
        )
        if full:
            eager = (joinedload(Space.owners), joinedload(Space.creator))
            stmt = stmt.options(*eager)
        rows = session.execute(stmt).unique()
        return rows.scalar_one_or_none()

    @staticmethod
    def get_unupdated(session: Session, user_pk: int) -> list[Space]:
        """Return the spaces of ``user_pk`` whose ``updated`` flag is false."""
        stale = (
            select(Space)
            .where(Space.user_account_id == user_pk)
            .where(Space.updated.is_(False))
        )
        return list(session.execute(stale).scalars())

    @staticmethod
    def get_rooms(
        session: Session,
        user_pk: int,
        legacy_id: str,
        room_legacy_ids: Iterable[str] = (),
    ) -> list[Room]:
        """Return the rooms of ``user_pk`` attached to the space ``legacy_id``.

        Only ``Room.jid`` and ``Room.name`` are requested via ``load_only``.

        :param room_legacy_ids: If truthy, restrict the result to rooms whose
            legacy ID is in this collection.
        """
        q = (
            select(Room)
            .join(Room.space)
            .where(Room.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
            .options(load_only(Room.jid, Room.name))
        )
        if room_legacy_ids:
            q = q.where(Room.legacy_id.in_(room_legacy_ids))
        rooms = session.execute(q).scalars()
        return list(rooms)

    @staticmethod
    def exists(session: Session, user_pk: int, legacy_id: str) -> bool:
        """Tell whether ``user_pk`` has a space with this ``legacy_id``."""
        matching_space = (
            sa.exists()
            .where(Space.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
        )
        return session.execute(select(matching_space)).scalar_one()

819 

820 

@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session: Session, flush_context: sa.ExecutionContext) -> None:
    """Delete avatars left unreferenced once contacts or rooms have been deleted.

    Runs after every flush. The avatar IDs of the ``Contact``/``Room`` objects
    in ``session.deleted`` are collected, and those no longer referenced by any
    contact or room are removed. ``flush_context`` is not used.
    """
    candidates = {
        obj.avatar_id
        for obj in session.deleted
        if isinstance(obj, (Contact, Room)) and obj.avatar_id
    }
    if not candidates:
        return

    used_by_contact = sa.exists().where(Contact.avatar_id == Avatar.id)
    used_by_room = sa.exists().where(Room.avatar_id == Avatar.id)
    purge = sa.delete(Avatar).where(
        sa.and_(
            Avatar.id.in_(candidates),
            sa.not_(used_by_contact),
            sa.not_(used_by_room),
        )
    )
    outcome = session.execute(purge)
    log.debug(
        "Auto-deleted %s orphaned avatars",
        outcome.rowcount,  # type:ignore[attr-defined]
    )

844 

845 

# Module-level logger shared by every store defined in this file.
log = logging.getLogger(__name__)