Coverage for slidge / db / store.py: 89%

345 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import logging 

5import shutil 

6import uuid 

7from collections.abc import Callable, Collection, Iterator 

8from datetime import UTC, datetime, timedelta 

9from mimetypes import guess_extension 

10from typing import Any, ClassVar 

11 

12import sqlalchemy as sa 

13from slixmpp.exceptions import XMPPError 

14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary 

15from sqlalchemy import ColumnElement, Engine, delete, event, select, update 

16from sqlalchemy.exc import InvalidRequestError 

17from sqlalchemy.orm import Session, attributes, sessionmaker 

18 

19from ..core import config 

20from ..util.archive_msg import HistoryMessage 

21from ..util.types import MamMetadata, Sticker 

22from .meta import Base 

23from .models import ( 

24 ArchivedMessage, 

25 ArchivedMessageSource, 

26 Avatar, 

27 Bob, 

28 Contact, 

29 ContactSent, 

30 DirectMessages, 

31 DirectThreads, 

32 GatewayUser, 

33 GroupMessages, 

34 GroupMessagesOrigin, 

35 GroupThreads, 

36 Participant, 

37 Room, 

38) 

39 

40 

class UpdatedMixin:
    """Mixin for stores whose mapped model has an ``updated`` column.

    Building the store clears ``updated`` on every row of :attr:`model`.
    Concrete stores assign the mapped class they manage to :attr:`model`.
    """

    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> type[Base]:
        """Return the :attr:`model` row whose ``id`` is *pk*, or None."""
        mapped = self.model
        query = select(mapped).where(mapped.id == pk)  # type:ignore[attr-defined]
        return session.scalar(query)  # type:ignore[no-any-return]

    def reset_updated(self, session: Session) -> None:
        """Set ``updated`` to False on all rows of :attr:`model`."""
        statement = update(self.model).values(updated=False)
        session.execute(statement)

53 

54 

class SlidgeStore:
    """Entry point bundling every per-table store around one session factory.

    The stores whose constructors take a :class:`Session` are all built
    inside a single session, which is committed once they exist.
    """

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        factory = sessionmaker[Any](engine)
        self.session = factory

        self.users = UserStore(factory)
        self.avatars = AvatarStore(factory)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        with factory() as startup_session:
            self.contacts = ContactStore(startup_session)
            self.mam = MAMStore(startup_session, factory)
            self.rooms = RoomStore(startup_session)
            self.participants = ParticipantStore(startup_session)
            startup_session.commit()

70 

71 

class UserStore:
    """Persistence helper for :class:`GatewayUser` rows."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Attach *user* to a fresh session (no expiry on commit) and commit it.

        ``legacy_module_data`` and ``preferences`` are flagged as modified
        explicitly. See https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        for why this is needed.
        """
        with self.session(expire_on_commit=False) as session:
            try:
                for attr_name in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, attr_name)
            except InvalidRequestError:
                # flag_modified refuses attributes that are not loaded on the
                # instance; there is nothing to flag in that case.
                pass
            session.add(user)
            session.commit()

86 

87 

class AvatarStore:
    """Keeps the session factory used for avatar persistence."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session: sessionmaker[Any] = session_maker

91 

92 

# The mapped classes IdMapStore uses to associate legacy IDs with XMPP IDs;
# IdMapStore reads ``id``, ``foreign_key``, ``legacy_id`` and ``xmpp_id``
# on whichever of them it is given.
LegacyToXmppType = (
    type[DirectMessages] | type[DirectThreads]
    | type[GroupMessages] | type[GroupThreads]
    | type[GroupMessagesOrigin]
)

100 

101 

class IdMapStore:
    """Mapping between legacy message/thread IDs and XMPP IDs.

    Rows are keyed by ``(foreign_key, legacy_id)`` and carry an ``xmpp_id``.
    The ``group`` flag on the public methods selects between the direct and
    group variants of the message and thread tables.
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace the XMPP IDs mapped to (*foreign_key*, *legacy_id*) in *type_*.

        Existing rows for that pair are deleted, then one row is added per
        entry in *xmpp_ids*.
        """
        # Materialise the result: a ScalarResult object is always truthy, so
        # testing it directly ran the log + DELETE below on every call.
        ids = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            session.add(
                type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map a legacy thread ID to a single XMPP thread ID."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map a legacy message ID to one or more XMPP message IDs."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Map a legacy group message ID to a single ID in GroupMessagesOrigin."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the GroupMessagesOrigin XMPP IDs stored for *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return all XMPP IDs mapped to (*foreign_key*, *legacy_id*) in *type_*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to a legacy message ID."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID mapped to *xmpp_id* in *type_*, or None."""
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID for *xmpp_id*, or None.

        *origin* is only honoured for group chats, where it makes the lookup
        use GroupMessagesOrigin instead of GroupMessages.
        """
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID mapped to XMPP thread *xmpp_id*, or None."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Return True if a message mapping exists for *legacy_id*."""
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )

250 

251 

class ContactStore(UpdatedMixin):
    """Store for :class:`Contact` rows and the message IDs each contact sent."""

    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Clear the cached-presence flag, then the caps version, on all contacts.
        session.execute(update(Contact).values(cached_presence=False))
        session.execute(update(Contact).values(caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Record that contact *contact_pk* sent *msg_id*, unless already recorded."""
        already_recorded = (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        )
        if already_recorded is None:
            session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))
            return
        log.warning("Contact %s has already sent message %s", contact_pk, msg_id)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Delete and return the contact's sent IDs, oldest first, up to *msg_id*.

        Rows are consumed in ``ContactSent.id`` order, and *msg_id* itself is
        included. If *msg_id* is never reached, every sent row of the contact
        is popped.
        """
        sent_rows = session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars()
        popped_pks: list[int] = []
        popped_msg_ids: list[str] = []
        for sent in sent_rows:
            popped_pks.append(sent.id)
            popped_msg_ids.append(sent.msg_id)
            if sent.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(popped_pks)))
        return popped_msg_ids

288 

289 

class MAMStore:
    """Per-room message archive backed by :class:`ArchivedMessage` rows.

    Methods operate on a caller-supplied :class:`Session`; the session
    factory given to the constructor is kept as ``self.session``.
    """

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Mark every archived message as ``ArchivedMessageSource.BACKFILL``."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than *days* days old.

        ``get_messages`` re-attaches ``UTC`` to stored timestamps, i.e. they are
        treated as UTC, so the cutoff is a naive UTC datetime too. A local-time
        ``datetime.now()`` would shift the cutoff by the host's UTC offset.
        """
        cutoff = datetime.now(UTC).replace(tzinfo=None) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Archive *message* in room *room_pk*, or update its existing entry.

        An existing entry is located by stanza ID first, then (if given) by
        legacy ID. *archive_only* selects the BACKFILL source, otherwise LIVE.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of *stanza_id* in the room; raise item-not-found if absent."""
        stamp = session.execute(
            select(ArchivedMessage.timestamp).where(
                ArchivedMessage.stanza_id == stanza_id,
                ArchivedMessage.room_id == room_pk,
            )
        ).scalar_one_or_none()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp  # type:ignore[no-any-return]

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield the archived messages of room *room_pk* matching the filters.

        :param start_date: only messages at or after this time
        :param end_date: only messages at or before this time
        :param before_id: only messages strictly older than this stanza ID
        :param after_id: only messages strictly newer than this stanza ID
        :param ids: only these stanza IDs; every distinct ID must be found
        :param last_page_n: keep only the *n* most recent matching messages
            (``0`` yields nothing)
        :param sender: only messages whose ``author_jid`` equals this
        :param flip: order newest first instead of oldest first
        :raises XMPPError: ``item-not-found`` when *before_id*, *after_id* or
            one of *ids* is unknown (raised once the generator is iterated)
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = MAMStore._timestamp_of(session, room_pk, before_id)
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = MAMStore._timestamp_of(session, room_pk, after_id)
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        # Compare against distinct IDs: IN (...) returns each row once, so a
        # duplicated requested ID must not count as "not found".
        if ids and len(msgs) != len(set(ids)):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                # msgs[-0:] is the whole list, so 0 needs its own branch.
                msgs = msgs[-last_page_n:] if last_page_n else []
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=UTC)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room (optionally with a legacy ID)."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally filtered by source."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest and newest messages (empty if none)."""
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message that has a legacy ID, optionally filtered by source."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID newer than legacy message *after_id*.

        Returns None when nothing matches, including when *after_id* is not
        in the room's archive.
        """
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        if after_timestamp is None:
            # Without a reference timestamp there is nothing to compare to;
            # ``column > None`` would raise ArgumentError in SQLAlchemy.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return an archived message of the room with *legacy_id*, or None."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark not-yet-displayed messages up to *stanza_id* as displayed.

        Only messages with a legacy ID are considered. If *stanza_id* is not
        archived, every such message of the room is marked. Returns the
        stanza IDs that were marked, oldest first.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        # Distinct loop name so the *stanza_id* parameter is not shadowed.
        for id_, row_stanza_id in rows:
            pks.append(id_)
            stanza_ids.append(row_stanza_id)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Return True if any message with *legacy_msg_id* in room *room_jid* was displayed."""
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )

552 

553 

class RoomStore(UpdatedMixin):
    """Store for :class:`Room` rows."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear ``updated`` plus subject setter, user resources and fill flags on all rooms."""
        super().reset_updated(session)
        cleared_columns = dict(
            subject_setter=None,
            user_resources=None,
            history_filled=False,
            participants_filled=False,
        )
        session.execute(update(Room).values(**cleared_columns))

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room belonging to user *user_pk*."""
        owned_rooms = select(Room).where(Room.user_account_id == user_pk)
        for room in session.scalars(owned_rooms):
            yield room

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the user's room with *legacy_id*; raise unless exactly one exists."""
        stmt = (
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        )
        result = session.execute(stmt)
        return result.scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Return True if no participant of room *room_pk* uses *nickname*."""
        holder = session.execute(
            select(Participant.id).filter_by(room_id=room_pk, nickname=nickname)
        ).one_or_none()
        return holder is None

587 

588 

class ParticipantStore:
    """Store for :class:`Participant` rows; all rows are deleted on construction."""

    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*, optionally excluding the user."""
        criteria = [Participant.room_id == room_pk]
        if not user_included:
            criteria.append(~Participant.is_user)
        statement = select(Participant).where(*criteria)
        yield from session.scalars(statement).unique()

601 

602 

class BobStore:
    """File-backed store for Bits of Binary (XEP-0231) payloads.

    Payloads are written as files under ``config.HOME_DIR / "bob_store"``;
    :class:`Bob` rows hold the file name, content type and SHA-1/256/512
    hex digests.
    """

    # Algorithm name as it appears in a CID -> Bob column holding that digest.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob column -> hash constructor used to compute its digest.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        # Move the older "slidge_stickers" directory to its new name if present.
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split ``<alg>+<digest>[@bob.xmpp.org]`` into ``[alg, digest]``."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build the WHERE clause matching *cid*'s digest.

        :raises ValueError: if *cid* is malformed or uses an unknown algorithm
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the Bob row for *cid*, or None if absent or *cid* is invalid."""
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()  # type:ignore[no-any-return]
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored payload for *cid* as a :class:`Sticker`, or None."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Build a BitsOfBinary stanza element for *cid*, or return None if unknown."""
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the Bob row for *cid* and its payload file."""
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        path = self.root_dir / file_name
        try:
            path.unlink()
        except FileNotFoundError:
            # The row is already gone; a missing payload file should not turn
            # the deletion into an error.
            log.warning("BoB file %s for CID %s was already missing", path, cid)

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> None:
        """Store *bob* unless an entry for its CID already exists.

        :raises ValueError: if the payload does not hash to the CID's digest
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, cid)
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        # Verify the CID before touching the disk so bad payloads never hit it.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)

717 

718 

@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session: Session, flush_context: sa.ExecutionContext) -> None:
    """Delete avatars no longer referenced after contacts/rooms were deleted.

    Registered as an ``after_flush`` listener on the ORM ``Session`` class.
    Only avatars referenced by a Contact or Room deleted in this flush are
    candidates. A candidate is removed only if no remaining Contact or Room
    points to it.
    """
    candidate_ids = {
        obj.avatar_id
        for obj in session.deleted
        if isinstance(obj, (Contact, Room)) and obj.avatar_id
    }
    if not candidate_ids:
        return

    used_by_contact = sa.exists().where(Contact.avatar_id == Avatar.id)
    used_by_room = sa.exists().where(Room.avatar_id == Avatar.id)
    outcome = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(candidate_ids),
                sa.not_(used_by_contact),
                sa.not_(used_by_room),
            )
        )
    )
    removed = outcome.rowcount  # type:ignore[attr-defined]
    log.debug("Auto-deleted %s orphaned avatars", removed)

742 

743 

# Module logger. It is defined last, but the functions above only look it up
# when they run, so that ordering is fine.
log = logging.getLogger(name=__name__)