Coverage for slidge / db / store.py: 90%
388 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
1from __future__ import annotations
3import hashlib
4import logging
5import shutil
6import uuid
7from collections.abc import Callable, Collection, Iterable, Iterator
8from datetime import UTC, datetime, timedelta
9from mimetypes import guess_extension
10from typing import Any, ClassVar
12import sqlalchemy as sa
13from slixmpp.exceptions import XMPPError
14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15from sqlalchemy import ColumnElement, Engine, delete, event, select, update
16from sqlalchemy.exc import InvalidRequestError
17from sqlalchemy.orm import Session, attributes, joinedload, load_only, sessionmaker
19from ..core import config
20from ..util.archive_msg import HistoryMessage
21from ..util.types import MamMetadata, Sticker
22from .meta import Base
23from .models import (
24 ArchivedMessage,
25 ArchivedMessageSource,
26 Avatar,
27 Bob,
28 Contact,
29 ContactSent,
30 DirectMessages,
31 DirectThreads,
32 GatewayUser,
33 GroupMessages,
34 GroupMessagesOrigin,
35 GroupThreads,
36 Participant,
37 Room,
38 Space,
39)
class UpdatedMixin:
    """Mixin for stores whose model has an ``updated`` column.

    Instantiating the store sets ``updated`` to ``False`` on every row of
    :attr:`model` (via :meth:`reset_updated`, which subclasses may extend).
    """

    # ORM class managed by the concrete store; subclasses must override it.
    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> Base | None:
        """Return the :attr:`model` row whose ``id`` is *pk*, or ``None``."""
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore[attr-defined]
        return session.scalar(stmt)  # type:ignore[no-any-return]

    def reset_updated(self, session: Session) -> None:
        """Set ``updated`` to ``False`` on every row of :attr:`model`."""
        session.execute(update(self.model).values(updated=False))
class SlidgeStore:
    """Bundle of every per-table store, all sharing one engine."""

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker[Any](engine)
        maker = self.session

        # Stores that only keep the session factory (or nothing at all).
        self.users = UserStore(maker)
        self.avatars = AvatarStore(maker)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        # Stores whose constructors reset transient state in the database;
        # they share one session and are committed together at the end.
        with maker() as startup:
            self.contacts = ContactStore(startup)
            self.mam = MAMStore(startup, maker)
            self.rooms = RoomStore(startup)
            self.participants = ParticipantStore(startup)
            self.spaces = SpaceStore(startup)
            startup.commit()
class UserStore:
    """Persistence helper for :class:`GatewayUser` rows."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Persist *user* and commit.

        ``legacy_module_data`` and ``preferences`` are explicitly flagged as
        modified before saving. Presumably they hold mutable values whose
        in-place changes SQLAlchemy cannot detect; see
        https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        """
        with self.session(expire_on_commit=False) as session:
            try:
                for column in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, column)
            except InvalidRequestError:
                # flag_modified() raises this when the attribute has no value
                # loaded on the instance; in that case there is nothing to flag.
                pass
            session.add(user)
            session.commit()
class AvatarStore:
    """Holder of the session factory used for avatar-related queries."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        # Kept under the same attribute name as the other factory-based stores.
        self.session: sessionmaker[Any] = session_maker
# Union of the ORM classes that map a legacy ID to an XMPP ID. IdMapStore's
# generic helpers take one of these classes as their ``type_`` argument and
# rely on each having ``id``, ``foreign_key``, ``legacy_id`` and ``xmpp_id``
# columns.
LegacyToXmppType = (
    type[DirectMessages]
    | type[DirectThreads]
    | type[GroupMessages]
    | type[GroupThreads]
    | type[GroupMessagesOrigin]
)
class IdMapStore:
    """Bidirectional mapping between legacy IDs and XMPP IDs.

    Every method works inside the caller's session and never commits.
    ``foreign_key`` identifies the owner of the mapping; presumably a contact
    or room primary key depending on the table (TODO confirm against models).
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace every XMPP ID mapped to *legacy_id* with *xmpp_ids*."""
        # Materialise the rows: a ScalarResult object is always truthy, so
        # testing it directly logged and issued a DELETE on every call.
        ids = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            session.add(
                type_(foreign_key=foreign_key, legacy_id=legacy_id, xmpp_id=xmpp_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map legacy thread *legacy_id* to the single XMPP thread *xmpp_id*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map legacy message *legacy_id* to one or more XMPP message IDs."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Record the group-message origin ID *xmpp_id* for *legacy_id*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the origin IDs recorded for group message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return all XMPP IDs of *type_* mapped to *legacy_id*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to legacy message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID mapped to *xmpp_id* in *type_*, or ``None``."""
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID mapped to *xmpp_id*, or ``None``.

        *origin* is only honoured for group chats; it makes the lookup use the
        origin-ID table instead of the message-ID table.
        """
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID mapped to XMPP thread *xmpp_id*."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Return whether *legacy_id* has a recorded XMPP message mapping."""
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )
class ContactStore(UpdatedMixin):
    """Persistence helpers for :class:`Contact` rows and their sent-ID queue."""

    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Forget cached presence and capabilities hashes from a previous run.
        session.execute(update(Contact).values(cached_presence=False))
        session.execute(update(Contact).values(caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Queue *msg_id* as sent by contact *contact_pk*; duplicates are ignored."""
        already_queued = (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        )
        if already_queued is None:
            session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))
            return
        log.warning("Contact %s has already sent message %s", contact_pk, msg_id)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return queued IDs, oldest first, up to and including *msg_id*.

        If *msg_id* is not queued, the whole queue of the contact is popped.
        """
        queued = session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars()
        popped_pks: list[int] = []
        popped_ids: list[str] = []
        for entry in queued:
            popped_pks.append(entry.id)
            popped_ids.append(entry.msg_id)
            if entry.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(popped_pks)))
        return popped_ids
class MAMStore:
    """Group-chat message archive backed by :class:`ArchivedMessage` rows.

    The static helpers work inside the caller's session and never commit.
    """

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Mark every archived message as coming from ``BACKFILL``."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than *days* days old."""
        # Timestamps are read back naive and tagged as UTC in get_messages(),
        # so the cutoff must be naive UTC as well. Naive *local* time would
        # shift the cutoff by the host's UTC offset.
        # NOTE(review): assumes the column stores naive UTC values — confirm.
        cutoff = datetime.now(UTC).replace(tzinfo=None) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Insert *message* in the archive of room *room_pk*, or update it.

        An existing row is matched by stanza ID first, then by legacy ID. When
        updating, ``legacy_id`` is overwritten with *legacy_msg_id* even if it
        is ``None``. *archive_only* selects ``BACKFILL`` as the source,
        otherwise ``LIVE``.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of room *room_pk* matching the filters.

        Messages are ordered by timestamp (newest first when *flip*). When
        *last_page_n* is given, only that many messages are kept: the first
        ones if *flip*, the last ones otherwise. This is a generator, so the
        ``item-not-found`` :class:`XMPPError` (unknown *before_id* /
        *after_id*, or some of *ids* missing) is raised on first iteration.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        # Compare against the *distinct* requested IDs: the IN filter returns
        # each message once even if the request lists it several times.
        if ids and len(msgs) != len(set(ids)):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if last_page_n <= 0:
                # msgs[-0:] is the *whole* list, not an empty page.
                msgs = []
            elif flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=UTC)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room, or ``None``.

        With *with_legacy_id*, only messages that have a legacy ID count.
        """
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally by *source*."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest and newest messages (empty if none)."""
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message that has a legacy ID, optionally by *source*."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID newer than *after_id*.

        *after_id* is a legacy ID. Returns ``None`` when it is unknown in this
        room or when no newer message from *source* exists.
        """
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        if after_timestamp is None:
            # Comparing against NULL would match nothing anyway; skip the query.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return the first archived message of the room with *legacy_id*."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark undisplayed messages as displayed and return their stanza IDs.

        Only messages with a legacy ID are considered, up to and including the
        timestamp of message *stanza_id*. If that message is unknown, every
        undisplayed message with a legacy ID is popped.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        # Distinct loop name so the ``stanza_id`` parameter is not shadowed.
        for id_, row_stanza_id in rows:
            pks.append(id_)
            stanza_ids.append(row_stanza_id)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Return whether any message with *legacy_msg_id* in that room was displayed."""
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )
class RoomStore(UpdatedMixin):
    """Persistence helpers for :class:`Room` rows."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear ``updated`` and the per-run state stored on every room."""
        super().reset_updated(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room belonging to user *user_pk*."""
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the user's room with *legacy_id*; raises if there is not exactly one."""
        return session.execute(
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        ).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Return ``True`` if no participant of room *room_pk* uses *nickname*."""
        # EXISTS rather than one_or_none(): the latter raises
        # MultipleResultsFound if several participants share the nickname.
        taken = session.execute(
            select(
                sa.exists()
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            )
        ).scalar_one()
        return not taken
class ParticipantStore:
    """Persistence helpers for group-chat :class:`Participant` rows."""

    def __init__(self, session: Session) -> None:
        # Every run starts with an empty participant table.
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*.

        When *user_included* is false, rows flagged ``is_user`` are skipped.
        """
        criteria = [Participant.room_id == room_pk]
        if not user_included:
            criteria.append(~Participant.is_user)
        yield from session.scalars(select(Participant).where(*criteria)).unique()

    @staticmethod
    def delete(session: Session, pk: int) -> None:
        """Delete the participant whose primary key is *pk*."""
        session.execute(delete(Participant).where(Participant.id == pk))
class BobStore:
    """Bits of Binary store: metadata in :class:`Bob` rows, bytes on disk.

    Files live under ``config.HOME_DIR / "bob_store"`` with random names. Each
    row records the file name, the content type and SHA-1/256/512 digests.
    """

    # Hash-algorithm prefix of a CID -> matching Bob digest column.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob digest column -> hashlib constructor used to fill it.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        self.root_dir = config.HOME_DIR / "bob_store"
        legacy_dir = config.HOME_DIR / "slidge_stickers"
        if legacy_dir.exists():
            if self.root_dir.exists():
                # shutil.move() would nest the old directory *inside* the new
                # one, where none of the recorded file names would be found.
                log.warning(
                    "Both %s and %s exist, not migrating the old directory",
                    legacy_dir,
                    self.root_dir,
                )
            else:
                shutil.move(legacy_dir, self.root_dir)
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split ``alg+digest@bob.xmpp.org`` into its ``+``-separated parts."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build the ``Bob.<digest column> == digest`` filter for *cid*.

        Raises ValueError if the CID is malformed or uses an unknown algorithm.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the Bob row for *cid*, or ``None`` if unknown or unparsable."""
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()  # type:ignore[no-any-return]
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored data for *cid* as a :class:`Sticker`, or ``None``."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return self.__sticker_from_bob(bob)

    def __sticker_from_bob(self, bob: Bob) -> Sticker:
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Build a BitsOfBinary stanza for *cid* from disk, or return ``None``."""
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the row for *cid* and its file; unknown CIDs only log a warning."""
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        path = self.root_dir / file_name
        try:
            path.unlink()
        except FileNotFoundError:
            # The row is already gone; a missing file must not abort the call.
            log.warning("BoB file %s was already missing", path)
        return None

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> Sticker | None:
        """Store an incoming BitsOfBinary stanza; see :meth:`set_sticker`."""
        # An absent ``type`` attribute typically reads back as "" from slixmpp
        # stanzas; map it to None so the content type gets detected.
        return self.set_sticker(session, bob["cid"], bob["data"], bob["type"] or None)

    def set_sticker(
        self,
        session: Session,
        cid: str,
        bytes_: bytes,
        content_type: str | None,
    ) -> Sticker | None:
        """Store *bytes_* under *cid* and return it as a :class:`Sticker`.

        Returns ``None`` if the CID is malformed, uses an unknown algorithm, or
        is already stored. Raises ValueError if the digest in *cid* does not
        match *bytes_*; nothing is written to disk in that case.
        """
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return None
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return None
        existing = self.get(session, cid)
        if existing:
            log.debug("Bob already exists")
            return None
        # Verify the digest before touching the filesystem.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        if content_type is None:
            try:
                import magic
            except ImportError:
                content_type = "application/octet-stream"
            else:
                content_type = magic.from_buffer(bytes_, mime=True)
        path = (self.root_dir / uuid.uuid4().hex).with_suffix(
            guess_extension(content_type) or ""
        )
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=content_type, **hashes)
        session.add(row)
        return self.__sticker_from_bob(row)
class SpaceStore(UpdatedMixin):
    """Persistence helpers for :class:`Space` rows."""

    model = Space

    def __init__(self, session: Session) -> None:
        # The table is emptied at start-up instead of having its ``updated``
        # flags reset, so UpdatedMixin.__init__ is not called.
        session.execute(delete(Space))

    @staticmethod
    def add_or_get(session: Session, user_pk: int, legacy_id: str) -> Space:
        """Return the user's space with *legacy_id*, creating it if missing.

        Unlike the other helpers, this commits the session when it creates a row.
        """
        space = session.execute(
            select(Space)
            .where(Space.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
        ).scalar_one_or_none()
        if space is None:
            space = Space(user_account_id=user_pk, legacy_id=legacy_id)
            session.add(space)
            session.commit()
        return space

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterable[Space]:
        """Return every space of user *user_pk* (a lazily consumed result)."""
        return session.execute(
            select(Space).where(Space.user_account_id == user_pk)
        ).scalars()

    @staticmethod
    def get_by_legacy_id(
        session: Session, user_pk: int, legacy_id: str, full: bool = False
    ) -> Space | None:
        """Return the user's space with *legacy_id*, or ``None``.

        With *full*, ``owners`` and ``creator`` are eagerly loaded.
        """
        stmt = (
            select(Space)
            .where(Space.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
        )
        if full:
            stmt = stmt.options(
                joinedload(Space.owners),
                joinedload(Space.creator),
            )
        return session.execute(stmt).unique().scalar_one_or_none()

    @staticmethod
    def get_unupdated(session: Session, user_pk: int) -> list[Space]:
        """Return the user's spaces whose ``updated`` flag is false."""
        return list(
            session.execute(
                select(Space)
                .where(Space.user_account_id == user_pk)
                .where(Space.updated.is_(False))
            ).scalars()
        )

    @staticmethod
    def get_rooms(
        session: Session,
        user_pk: int,
        legacy_id: str,
        room_legacy_ids: Iterable[str] = (),
    ) -> list[Room]:
        """Return the user's rooms in space *legacy_id* (only ``jid``/``name`` loaded).

        A non-empty *room_legacy_ids* restricts the result to those rooms.
        """
        # Materialise first: a generator is truthy even when empty, which
        # would turn "no filter" into "match nothing".
        wanted = list(room_legacy_ids)
        q = (
            select(Room)
            .join(Room.space)
            .where(Room.user_account_id == user_pk)
            .where(Space.legacy_id == legacy_id)
            .options(load_only(Room.jid, Room.name))
        )
        if wanted:
            q = q.where(Room.legacy_id.in_(wanted))
        return list(session.execute(q).scalars())

    @staticmethod
    def exists(session: Session, user_pk: int, legacy_id: str) -> bool:
        """Return whether the user has a space with *legacy_id*."""
        return session.execute(
            select(
                sa.exists()
                .where(Space.user_account_id == user_pk)
                .where(Space.legacy_id == legacy_id)
            )
        ).scalar_one()
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(
    session: Session, flush_context: sa.orm.UOWTransaction
) -> None:
    """Delete avatars no longer referenced after contacts/rooms were deleted.

    Registered as an ``after_flush`` listener on every ``Session``. Only the
    avatars of :class:`Contact` / :class:`Room` objects deleted in this flush
    are considered. Each is removed only if no remaining contact or room
    still references it.
    """
    if not session.deleted:
        return

    potentially_orphaned = {
        obj.avatar_id
        for obj in session.deleted
        if isinstance(obj, (Contact, Room)) and obj.avatar_id is not None
    }
    if not potentially_orphaned:
        return

    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(potentially_orphaned),
                sa.not_(sa.exists().where(Contact.avatar_id == Avatar.id)),
                sa.not_(sa.exists().where(Room.avatar_id == Avatar.id)),
            )
        )
    )
    deleted_count = result.rowcount  # type:ignore[attr-defined]
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)
# Module logger. It is defined after the functions and classes that use it;
# that works because they only look the name up when they are called.
log = logging.getLogger(__name__)