Coverage for slidge / db / store.py: 90%

388 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-20 19:56 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import logging 

5import shutil 

6import uuid 

7from collections.abc import Callable, Collection, Iterable, Iterator 

8from datetime import UTC, datetime, timedelta 

9from mimetypes import guess_extension 

10from typing import Any, ClassVar 

11 

12import sqlalchemy as sa 

13from slixmpp.exceptions import XMPPError 

14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary 

15from sqlalchemy import ColumnElement, Engine, delete, event, select, update 

16from sqlalchemy.exc import InvalidRequestError 

17from sqlalchemy.orm import Session, attributes, joinedload, load_only, sessionmaker 

18 

19from ..core import config 

20from ..util.archive_msg import HistoryMessage 

21from ..util.types import MamMetadata, Sticker 

22from .meta import Base 

23from .models import ( 

24 ArchivedMessage, 

25 ArchivedMessageSource, 

26 Avatar, 

27 Bob, 

28 Contact, 

29 ContactSent, 

30 DirectMessages, 

31 DirectThreads, 

32 GatewayUser, 

33 GroupMessages, 

34 GroupMessagesOrigin, 

35 GroupThreads, 

36 Participant, 

37 Room, 

38 Space, 

39) 

40 

41 

class UpdatedMixin:
    """Mixin for stores whose mapped table has an ``updated`` boolean column.

    Subclasses point :attr:`model` at the mapped class they manage. Building
    the store sets ``updated`` to False on every row of that table.
    """

    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> type[Base]:
        """Fetch the :attr:`model` row whose ``id`` equals *pk* (None if absent)."""
        mapped = self.model
        stmt = select(mapped).where(mapped.id == pk)  # type:ignore[attr-defined]
        return session.scalar(stmt)  # type:ignore[no-any-return]

    def reset_updated(self, session: Session) -> None:
        """Clear the ``updated`` flag on every row of :attr:`model`."""
        stmt = update(self.model).values(updated=False)
        session.execute(stmt)

54 

55 

class SlidgeStore:
    """Bundle of every per-table store, all sharing one SQLAlchemy engine."""

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker[Any](engine)

        # These stores need no database session while being constructed.
        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        # These stores modify rows in their constructors, so they share a
        # single session that is committed once all of them are built.
        with self.session() as startup_session:
            self.contacts = ContactStore(startup_session)
            self.mam = MAMStore(startup_session, self.session)
            self.rooms = RoomStore(startup_session)
            self.participants = ParticipantStore(startup_session)
            self.spaces = SpaceStore(startup_session)
            startup_session.commit()

72 

73 

class UserStore:
    """Persistence helper for :class:`GatewayUser` rows."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Add *user* to a fresh session and commit it.

        ``legacy_module_data`` and ``preferences`` are flagged as modified
        first, so in-place changes to those attributes are written as well.
        """
        with self.session(expire_on_commit=False) as db:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                for attr_name in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, attr_name)
            except InvalidRequestError:
                # flag_modified refuses attributes it cannot flag on this
                # instance; the user is saved regardless.
                pass
            db.add(user)
            db.commit()

88 

89 

class AvatarStore:
    """Keeps the session factory used for avatar persistence."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        # Stored under the same attribute name the other stores use.
        self.session = session_maker

94 

# The ID-mapping tables IdMapStore reads and writes. Its helpers pick a
# "Direct*" table when ``group`` is false and a "Group*" table when it is true.
LegacyToXmppType = (
    type[DirectMessages] | type[DirectThreads]
    | type[GroupMessages] | type[GroupThreads] | type[GroupMessagesOrigin]
)

102 

103 

class IdMapStore:
    """Maps legacy (network-side) message/thread IDs to XMPP IDs and back.

    Every method works inside the session it is given and never commits.
    ``foreign_key`` is the value of the mapping tables' ``foreign_key``
    column; presumably the owning contact/room primary key (confirm against
    the models).
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace the XMPP IDs mapped to *legacy_id* in table *type_*.

        Existing rows for (*foreign_key*, *legacy_id*) are deleted, then one
        row per entry of *xmpp_ids* is added to the session.
        """
        # Materialise the IDs: a ScalarResult object is always truthy, so
        # testing it directly made the reset branch (debug log + DELETE)
        # run on every call, even when nothing matched.
        ids = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            session.add(
                type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map legacy thread *legacy_id* to the single XMPP thread *xmpp_id*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map legacy message *legacy_id* to every ID in *xmpp_ids*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Record *xmpp_id* as the origin ID of group message *legacy_id*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the origin XMPP IDs recorded for group message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return every XMPP ID mapped to *legacy_id* in table *type_*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to legacy message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID mapped to *xmpp_id* in table *type_*, or None."""
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID mapped to *xmpp_id*, or None.

        When both *origin* and *group* are true, the origin-ID table is
        searched; otherwise the group or direct message table is searched.
        """
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID mapped to XMPP thread *xmpp_id*, or None."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Return True if any XMPP ID is mapped to legacy message *legacy_id*."""
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )

252 

253 

class ContactStore(UpdatedMixin):
    """Store for :class:`Contact` rows and the message IDs contacts have sent."""

    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Clear cached_presence and caps_ver on every contact, one UPDATE each.
        for reset_values in ({"cached_presence": False}, {"caps_ver": None}):
            session.execute(update(Contact).values(**reset_values))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Record that contact *contact_pk* sent *msg_id*, unless already recorded."""
        already_recorded = (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        )
        if already_recorded is None:
            session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))
        else:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Delete and return a contact's recorded sent IDs up to *msg_id*.

        Rows are walked in ascending ``ContactSent.id`` order, and the walk
        stops after the row whose ``msg_id`` equals *msg_id* (inclusive). If no
        row matches, every recorded ID of the contact is popped.
        """
        popped_pks: list[int] = []
        popped_ids: list[str] = []
        sent_rows = session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars()
        for sent in sent_rows:
            popped_pks.append(sent.id)
            popped_ids.append(sent.msg_id)
            if sent.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(popped_pks)))
        return popped_ids

290 

291 

class MAMStore:
    """Archive of room messages (the store behind ``SlidgeStore.mam``).

    Apart from ``__init__``, every method works inside the session it is
    given and leaves committing to the caller.
    """

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Mark every archived message as coming from BACKFILL."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than *days* days old."""
        # Use an aware UTC "now": stored timestamps are read back as UTC (see
        # get_messages). A naive local datetime.now() shifted the cutoff by
        # the host's UTC offset.
        cutoff = datetime.now(UTC) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Archive *message* in room *room_pk*, or update an existing entry.

        An existing entry is found by stanza ID first, then (if given) by
        legacy ID. The source is BACKFILL when *archive_only*, else LIVE.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of archived message *stanza_id* in room *room_pk*.

        :raises XMPPError: ``item-not-found`` if no such message is archived.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp).where(
                ArchivedMessage.stanza_id == stanza_id,
                ArchivedMessage.room_id == room_pk,
            )
        ).scalar_one_or_none()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of room *room_pk* matching the filters.

        *start_date*/*end_date* bound the timestamp inclusively. *before_id*
        and *after_id* bound it exclusively by the referenced message's
        timestamp. *ids* restricts to those stanza IDs, and *sender* to that
        author JID. Messages are oldest-first, or newest-first if *flip*.
        *last_page_n* keeps the chronologically last N (the first N when
        flipped). Timestamps are yielded with UTC attached.

        Being a generator, errors surface on the first iteration.

        :raises XMPPError: ``item-not-found`` if *before_id*/*after_id* is
            unknown, or if not every requested ID was found.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            before_stamp = MAMStore._timestamp_of(session, room_pk, before_id)
            q = q.where(ArchivedMessage.timestamp < before_stamp)
        if after_id is not None:
            after_stamp = MAMStore._timestamp_of(session, room_pk, after_id)
            q = q.where(ArchivedMessage.timestamp > after_stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        order = (
            ArchivedMessage.timestamp.desc() if flip else ArchivedMessage.timestamp.asc()
        )
        msgs = list(session.execute(q.order_by(order)).scalars())
        # Compare with the distinct IDs: a request naming the same ID twice
        # must not fail spuriously.
        if ids and len(msgs) != len(set(ids)):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                # Newest-first order: the last page is the head of the list.
                msgs = msgs[:last_page_n]
            else:
                # Explicit start index: msgs[-0:] would return everything
                # instead of nothing when last_page_n == 0.
                msgs = msgs[max(len(msgs) - last_page_n, 0) :]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=UTC)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room, optionally only
        among messages that have a legacy ID."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally
        restricted to one *source*."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata of the oldest and newest messages (empty if none)."""
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message that has a legacy ID, optionally
        restricted to one *source*."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest message from *source* that has a legacy ID and is
        strictly newer than the message whose legacy ID is *after_id*.

        Returns None if *after_id* is not archived in this room.
        """
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        if after_timestamp is None:
            # Unknown reference message: a "> NULL" comparison matches nothing.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return the first archived message of the room with *legacy_id*."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark not-yet-displayed messages as displayed and return their stanza IDs.

        Only messages with a legacy ID are considered, up to and including the
        timestamp of *stanza_id*. If *stanza_id* is not archived, every such
        message is popped.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        # Distinct loop name: the original reused (and shadowed) ``stanza_id``.
        for id_, sid in rows:
            pks.append(id_)
            stanza_ids.append(sid)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Return True if any message with *legacy_msg_id* in the room whose
        JID is *room_jid* has been marked as displayed by the user."""
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )

551 

552 

class RoomStore(UpdatedMixin):
    """Store for :class:`Room` rows."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear ``updated`` plus the per-run room state on every room.

        Also resets ``subject_setter``, ``user_resources``,
        ``history_filled`` and ``participants_filled``.
        """
        super().reset_updated(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room belonging to user *user_pk*."""
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the user's room with *legacy_id*.

        Uses ``scalar_one()``, so it raises if there is not exactly one match.
        """
        return session.execute(
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        ).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Return True if no participant of room *room_pk* uses *nickname*."""
        # limit(1) + first() instead of one_or_none(): the latter raised
        # MultipleResultsFound when duplicate rows existed, where the answer
        # is simply "not available".
        taken = session.execute(
            select(Participant.id)
            .filter_by(room_id=room_pk, nickname=nickname)
            .limit(1)
        ).first()
        return taken is None

586 

587 

class ParticipantStore:
    """Store for :class:`Participant` rows."""

    def __init__(self, session: Session) -> None:
        # Every participant row is deleted when the store is built.
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*.

        Rows with ``is_user`` set are skipped when *user_included* is False.
        """
        conditions: list[ColumnElement[bool]] = [Participant.room_id == room_pk]
        if not user_included:
            conditions.append(~Participant.is_user)
        yield from session.scalars(select(Participant).where(*conditions)).unique()

    @staticmethod
    def delete(session: Session, pk: int) -> None:
        """Delete the participant whose primary key is *pk*."""
        # ``delete`` here resolves to the module-level SQLAlchemy function,
        # not this method: function bodies do not see the class namespace.
        stmt = delete(Participant).where(Participant.id == pk)
        session.execute(stmt)

604 

605 

class BobStore:
    """File-backed store for Bits of Binary (BoB) payloads.

    Payload bytes are written as files under ``config.HOME_DIR / "bob_store"``.
    Each :class:`Bob` row records the file name, the content type and the
    SHA-1, SHA-256 and SHA-512 hex digests. CIDs are parsed as
    ``<algorithm>+<hex digest>@bob.xmpp.org``.
    """

    # Accepted spellings of the hash algorithm in a CID -> Bob column name.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob column name -> hashlib constructor producing that column's digest.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        legacy_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir = config.HOME_DIR / "bob_store"
        if legacy_dir.exists():
            if self.root_dir.exists():
                # shutil.move() into an existing directory would nest the old
                # one inside it (bob_store/slidge_stickers), where no stored
                # file name resolves; leave it in place instead.
                log.warning(
                    "Not migrating %s: %s already exists", legacy_dir, self.root_dir
                )
            else:
                shutil.move(legacy_dir, self.root_dir)
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Strip the ``@bob.xmpp.org`` suffix and split *cid* on ``+``."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build the WHERE clause matching the Bob row for *cid*.

        :raises ValueError: if *cid* does not split into exactly two parts or
            names an unknown hash algorithm.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the Bob row for *cid*, or None if *cid* is malformed/absent."""
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()  # type:ignore[no-any-return]
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored payload for *cid* as a :class:`Sticker`, or None."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return self.__sticker_from_bob(bob)

    def __sticker_from_bob(self, bob: Bob) -> Sticker:
        """Build a :class:`Sticker` (file path, content type, digests) from a row."""
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    # NOTE(review): the unused _jid/_node/_ifrom parameters below presumably
    # match a slixmpp API-callback signature; confirm against the registration.

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Return a BitsOfBinary stanza carrying the stored bytes, or None."""
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the Bob row for *cid* and its file on disk."""
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        # The row is already gone; a missing file must not turn that into an error.
        (self.root_dir / file_name).unlink(missing_ok=True)

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> Sticker | None:
        """Store the payload of a BitsOfBinary stanza (see :meth:`set_sticker`)."""
        return self.set_sticker(session, bob["cid"], bob["data"], bob["type"])

    def set_sticker(
        self,
        session: Session,
        cid: str,
        bytes_: bytes,
        content_type: str | None,
    ) -> Sticker | None:
        """Write *bytes_* to disk and add a Bob row for it.

        Returns None if *cid* is malformed, uses an unknown algorithm, or is
        already stored. When *content_type* is None, it is detected with
        ``magic`` if importable, else ``application/octet-stream`` is used.

        :raises ValueError: if the digest in *cid* does not match *bytes_*.
        """
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return None
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return None
        existing = self.get(session, cid)
        if existing:
            log.debug("Bob already exists")
            return None
        # Validate the CID before touching the disk, so a mismatch leaves no
        # stray file behind.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        if content_type is None:
            try:
                import magic
            except ImportError:
                content_type = "application/octet-stream"
            else:
                content_type = magic.from_buffer(bytes_, mime=True)
        path = (self.root_dir / uuid.uuid4().hex).with_suffix(
            guess_extension(content_type) or ""
        )
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=content_type, **hashes)
        session.add(row)
        return self.__sticker_from_bob(row)

737 

738 

class SpaceStore(UpdatedMixin):
    """Store for :class:`Space` rows."""

    model = Space

    def __init__(self, session: Session) -> None:
        # Unlike the other UpdatedMixin stores this does not call
        # super().__init__(): every row is deleted, so no flag is left to reset.
        session.execute(delete(Space))

    @staticmethod
    def add_or_get(session: Session, user_pk: int, legacy_id: str) -> Space:
        """Return the user's space *legacy_id*, creating it if needed.

        Creating a space commits *session*.
        """
        found = session.execute(
            select(Space).where(
                Space.user_account_id == user_pk, Space.legacy_id == legacy_id
            )
        ).scalar_one_or_none()
        if found is not None:
            return found
        created = Space(user_account_id=user_pk, legacy_id=legacy_id)
        session.add(created)
        session.commit()
        return created

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterable[Space]:
        """Return the (lazy) scalar result of every space of user *user_pk*."""
        stmt = select(Space).where(Space.user_account_id == user_pk)
        return session.execute(stmt).scalars()

    @staticmethod
    def get_by_legacy_id(
        session: Session, user_pk: int, legacy_id: str, full: bool = False
    ) -> Space | None:
        """Return the user's space *legacy_id*, or None.

        With *full*, ``owners`` and ``creator`` are eagerly joined-loaded.
        """
        eager = (joinedload(Space.owners), joinedload(Space.creator)) if full else ()
        stmt = (
            select(Space)
            .where(Space.user_account_id == user_pk, Space.legacy_id == legacy_id)
            .options(*eager)
        )
        return session.execute(stmt).unique().scalar_one_or_none()

    @staticmethod
    def get_unupdated(session: Session, user_pk: int) -> list[Space]:
        """List the user's spaces whose ``updated`` flag is False."""
        stmt = select(Space).where(
            Space.user_account_id == user_pk, Space.updated.is_(False)
        )
        return list(session.execute(stmt).scalars())

    @staticmethod
    def get_rooms(
        session: Session,
        user_pk: int,
        legacy_id: str,
        room_legacy_ids: Iterable[str] = (),
    ) -> list[Room]:
        """List the user's rooms in space *legacy_id* (only ``jid``/``name`` loaded).

        A truthy *room_legacy_ids* further restricts the rooms to those IDs.
        """
        stmt = (
            select(Room)
            .join(Room.space)
            .where(Room.user_account_id == user_pk, Space.legacy_id == legacy_id)
            .options(load_only(Room.jid, Room.name))
        )
        if room_legacy_ids:
            stmt = stmt.where(Room.legacy_id.in_(room_legacy_ids))
        return list(session.execute(stmt).scalars())

    @staticmethod
    def exists(session: Session, user_pk: int, legacy_id: str) -> bool:
        """Return True if user *user_pk* has a space with *legacy_id*."""
        probe = sa.exists().where(
            Space.user_account_id == user_pk, Space.legacy_id == legacy_id
        )
        return session.execute(select(probe)).scalar_one()

817 

818 

@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session: Session, flush_context: sa.ExecutionContext) -> None:
    """Delete avatars left unreferenced by Contacts/Rooms deleted in this flush.

    Only avatars referenced by a just-deleted Contact or Room are considered.
    Of those, the ones no Contact and no Room still points to are deleted.
    """
    candidate_ids = {
        gone.avatar_id
        for gone in session.deleted
        if isinstance(gone, (Contact, Room)) and gone.avatar_id
    }
    if not candidate_ids:
        return

    used_by_contact = sa.exists().where(Contact.avatar_id == Avatar.id)
    used_by_room = sa.exists().where(Room.avatar_id == Avatar.id)
    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(candidate_ids),
                sa.not_(used_by_contact),
                sa.not_(used_by_room),
            )
        )
    )
    deleted_count = result.rowcount  # type:ignore[attr-defined]
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)

842 

843 

# Module logger. Assigning it after the definitions above is fine: they look
# up ``log`` only when they run, not when they are defined.
log: logging.Logger = logging.getLogger(__name__)