Coverage for slidge / db / store.py: 90%
388 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 04:38 +0000
1from __future__ import annotations
3import hashlib
4import logging
5import shutil
6import uuid
7from collections.abc import Callable, Collection, Iterable, Iterator
8from datetime import UTC, datetime, timedelta
9from mimetypes import guess_extension
10from typing import Any, ClassVar
12import sqlalchemy as sa
13from slixmpp.exceptions import XMPPError
14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15from sqlalchemy import ColumnElement, Engine, delete, event, select, update
16from sqlalchemy.exc import InvalidRequestError
17from sqlalchemy.orm import Session, attributes, joinedload, load_only, sessionmaker
19from ..core import config
20from ..util.archive_msg import HistoryMessage
21from ..util.types import MamMetadata, Sticker
22from .meta import Base
23from .models import (
24 ArchivedMessage,
25 ArchivedMessageSource,
26 Avatar,
27 Bob,
28 Contact,
29 ContactSent,
30 DirectMessages,
31 DirectThreads,
32 GatewayUser,
33 GroupMessages,
34 GroupMessagesOrigin,
35 GroupThreads,
36 Participant,
37 Room,
38 Space,
39)
class UpdatedMixin:
    """Base for stores whose ``model`` has an ``updated`` boolean column.

    Constructing a store clears that flag on every row of ``model``.
    Subclasses set ``model`` to their mapped class.
    """

    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> type[Base]:
        """Return the ``model`` row whose ``id`` equals *pk*.

        NOTE(review): the annotation says ``type[Base]``, but ``session.scalar``
        returns a model instance (or ``None`` when no row matches).
        """
        target = self.model
        return session.scalar(  # type:ignore[no-any-return]
            select(target).where(target.id == pk)  # type:ignore[attr-defined]
        )

    def reset_updated(self, session: Session) -> None:
        """Set ``updated = False`` on every row of ``model``."""
        clear_flag = update(self.model).values(updated=False)
        session.execute(clear_flag)
class SlidgeStore:
    """Entry point bundling every store that shares one engine.

    Several store constructors run statements against the database; those run
    in one session that is committed once all of them are built.
    """

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        maker = sessionmaker[Any](engine)
        self.session = maker

        # Stores that keep only the session factory (or need no session).
        self.users = UserStore(maker)
        self.avatars = AvatarStore(maker)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        # Stores whose constructors execute statements; order kept as-is.
        with maker() as session:
            self.contacts = ContactStore(session)
            self.mam = MAMStore(session, maker)
            self.rooms = RoomStore(session)
            self.participants = ParticipantStore(session)
            self.spaces = SpaceStore(session)
            session.commit()
class UserStore:
    """Persist ``GatewayUser`` rows, opening a fresh session per call."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Add *user* to a new session and commit it.

        ``legacy_module_data`` and ``preferences`` are explicitly flagged as
        modified so in-place changes to them get written (see the linked
        SQLAlchemy discussion). If flagging raises ``InvalidRequestError`` the
        remaining flags are skipped and the user is saved anyway.
        """
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                for column_name in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, column_name)
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()
class AvatarStore:
    """Avatar store; currently it only keeps the session factory."""

    session: sessionmaker[Any]

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
# Any of the legacy-ID <-> XMPP-ID mapping models. IdMapStore treats them
# interchangeably: it only relies on the columns they share (id, foreign_key,
# legacy_id, xmpp_id).
LegacyToXmppType = (
    type[DirectMessages]
    | type[DirectThreads]
    | type[GroupMessages]
    | type[GroupThreads]
    | type[GroupMessagesOrigin]
)
class IdMapStore:
    """Map legacy message/thread IDs to XMPP IDs and back.

    Every method works in the caller's session and never commits. Legacy IDs
    are normalised with ``str()`` everywhere so writes and lookups agree even
    if a caller passes a non-string legacy ID.
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace all mappings of (*foreign_key*, *legacy_id*) with *xmpp_ids*."""
        legacy_id = str(legacy_id)
        # A single DELETE replaces the former SELECT-ids-then-DELETE round trip.
        result = session.execute(
            delete(type_).where(
                type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
            )
        )
        if result.rowcount:  # type:ignore[attr-defined]
            log.debug("Resetting legacy ID %s", legacy_id)
        for xmpp_id in xmpp_ids:
            session.add(
                type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map a legacy thread ID to a single XMPP thread ID."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map a legacy message ID to one or more XMPP message IDs."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Map a legacy group message ID to an XMPP ID in ``GroupMessagesOrigin``."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the XMPP IDs stored in ``GroupMessagesOrigin`` for *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return every XMPP ID mapped to (*foreign_key*, *legacy_id*) in *type_*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to a legacy message ID."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID mapped to *xmpp_id* in *type_*, or ``None``."""
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID for *xmpp_id*.

        *origin* selects ``GroupMessagesOrigin`` and only applies when *group*
        is true; for direct messages it is ignored.
        """
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID mapped to an XMPP thread ID, or ``None``."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Return True if any message mapping exists for *legacy_id*.

        NOTE(review): presumably such mappings are only written for messages
        the user sent from XMPP; confirm against callers of ``set_msg``.
        """
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
            is not None
        )
class ContactStore(UpdatedMixin):
    """Queries on ``Contact`` rows and their ``ContactSent`` message log.

    Construction clears ``updated`` (via UpdatedMixin) and resets
    ``cached_presence`` and ``caps_ver`` on every contact.
    """

    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Both columns are reset with a single UPDATE statement.
        session.execute(update(Contact).values(cached_presence=False, caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Record that contact *contact_pk* sent *msg_id*.

        A duplicate (contact, msg_id) pair is logged as a warning and ignored.
        """
        already_recorded = session.scalar(
            select(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .limit(1)
        )
        if already_recorded is None:
            session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))
            return
        log.warning("Contact %s has already sent message %s", contact_pk, msg_id)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return this contact's sent IDs, oldest first, up to *msg_id*.

        Rows are walked in ``id`` order and *msg_id* itself is included. If
        *msg_id* is absent, every row of the contact is popped.
        """
        popped_pks: list[int] = []
        popped_msg_ids: list[str] = []
        sent_rows = session.scalars(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        )
        for sent in sent_rows:
            popped_pks.append(sent.id)
            popped_msg_ids.append(sent.msg_id)
            if sent.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(popped_pks)))
        return popped_msg_ids
class MAMStore:
    """Message Archive Management (MAM) storage for rooms.

    All methods work in the caller's session and never commit. Timestamps
    are read back as naive values and interpreted as UTC (see
    ``get_messages``).
    """

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Set ``source = BACKFILL`` on every archived message."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than *days* old.

        The cutoff is computed as naive UTC to match how timestamps are
        interpreted on read, rather than the host's local time.
        """
        cutoff = datetime.now(UTC).replace(tzinfo=None) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Insert *message* into the room archive, or update it if already there.

        An existing row is matched first by stanza ID, then by legacy ID.
        *archive_only* selects source BACKFILL instead of LIVE.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of a room matching every given filter.

        *before_id* / *after_id* are exclusive bounds based on the referenced
        message's timestamp. Results are ordered by timestamp, ascending or
        descending if *flip*. *last_page_n* keeps only the n most recent
        matches.

        This is a generator: nothing runs, including the XMPPError
        ``item-not-found`` checks for unknown bound IDs or missing *ids*,
        until the first item is requested.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        # Compare against the distinct IDs: a request that repeats an ID must
        # not be reported as "not found".
        if ids and len(msgs) != len(set(ids)):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            # Descending order: most recent first; ascending: most recent last.
            msgs = msgs[:last_page_n] if flip else msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=UTC)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room, optionally with a legacy ID."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally of one *source*."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest and newest messages (empty if none)."""
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message that has a legacy ID, optionally of one *source*."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID that is newer than *after_id*.

        *after_id* is a legacy ID. Returns ``None`` when it is not archived in
        this room or when nothing of *source* comes after it.
        """
        # limit(1) tolerates duplicate legacy IDs; Query.scalar() would raise
        # MultipleResultsFound.
        after_timestamp = session.scalar(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id == after_id)
            .limit(1)
        )
        if after_timestamp is None:
            # Unknown reference: avoid issuing a "timestamp > NULL" query.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return one archived message of the room with this legacy ID, or ``None``."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark messages as displayed up to *stanza_id* and return their stanza IDs.

        Only messages with a legacy ID that are not yet displayed are
        considered. Those with a timestamp <= the referenced message's are
        marked, or all of them if *stanza_id* is not archived. IDs are
        returned oldest first.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        # Distinct loop name: do not shadow the stanza_id parameter.
        for row_pk, row_stanza_id in rows:
            pks.append(row_pk)
            stanza_ids.append(row_stanza_id)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Return True if any message with this legacy ID in the room is displayed."""
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )
class RoomStore(UpdatedMixin):
    """Queries on ``Room`` rows."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear ``updated`` and reset per-room state columns on every room.

        Sets ``subject_setter`` and ``user_resources`` to NULL and
        ``history_filled`` / ``participants_filled`` to False.
        """
        super().reset_updated(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room belonging to the gateway user *user_pk*."""
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the user's room with *legacy_id*; raises if not exactly one."""
        return session.execute(
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        ).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Return True if no participant of room *room_pk* uses *nickname*.

        Uses an EXISTS query so several participants sharing a nickname yield
        False instead of raising MultipleResultsFound.
        """
        taken = session.execute(
            select(
                sa.exists()
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            )
        ).scalar_one()
        return not taken
class ParticipantStore:
    """Queries on ``Participant`` rows.

    Constructing the store deletes every Participant row.
    """

    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*.

        Pass ``user_included=False`` to skip the row flagged ``is_user``.
        """
        criteria = [Participant.room_id == room_pk]
        if not user_included:
            criteria.append(~Participant.is_user)
        yield from session.scalars(select(Participant).where(*criteria)).unique()

    @staticmethod
    def delete(session: Session, pk: int) -> None:
        """Delete the participant whose primary key is *pk*."""
        stmt = delete(Participant).where(Participant.id == pk)
        session.execute(stmt)
class BobStore:
    """Store "Bits of Binary" blobs (also used for stickers).

    Hashes, content type and file name are kept in the ``Bob`` table; the
    bytes live in files under ``config.HOME_DIR / "bob_store"``.
    """

    # CID algorithm name (dashed and undashed spellings) -> Bob hash column.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob hash column -> hash function; each stored blob gets all three.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        # Rename a pre-existing "slidge_stickers" directory to "bob_store".
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split ``alg+digest@bob.xmpp.org`` into ``[alg, digest]``.

        Callers unpack the result into two names, so a CID that does not
        contain exactly one ``+`` surfaces as ``ValueError``.
        """
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build ``Bob.<hash column> == digest`` for *cid*.

        Raises ValueError for a malformed CID or an unknown algorithm.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the Bob row matching *cid*, or None (also for an invalid CID)."""
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()  # type:ignore[no-any-return]
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored blob for *cid* as a Sticker, or None."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return self.__sticker_from_bob(bob)

    def __sticker_from_bob(self, bob: Bob) -> Sticker:
        """Build a Sticker from a Bob row: file path, content type, all hashes."""
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Load the blob for *cid* into a BitsOfBinary element, or return None."""
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the Bob row for *cid* and its file.

        An already-missing file is tolerated, because the row is gone by then.
        """
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        (self.root_dir / file_name).unlink(missing_ok=True)

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> Sticker | None:
        """Store the data of a BitsOfBinary element (see set_sticker)."""
        return self.set_sticker(session, bob["cid"], bob["data"], bob["type"])

    def set_sticker(
        self,
        session: Session,
        cid: str,
        bytes_: bytes,
        content_type: str | None,
    ) -> Sticker | None:
        """Store *bytes_* under *cid* and return it as a Sticker.

        Returns None for a malformed CID, an unknown algorithm, or a CID that
        is already stored. Raises ValueError if *bytes_* does not hash to the
        CID's digest.
        """
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return None
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return None
        existing = self.get(session, cid)
        if existing:
            log.debug("Bob already exists")
            return None
        # Check the CID before touching the filesystem, so a mismatched
        # payload never leaves a file behind.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        if content_type is None:
            try:
                import magic
            except ImportError:
                content_type = "application/octet-stream"
            else:
                content_type = magic.from_buffer(bytes_, mime=True)
        path = (self.root_dir / uuid.uuid4().hex).with_suffix(
            guess_extension(content_type) or ""
        )
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=content_type, **hashes)
        session.add(row)
        return self.__sticker_from_bob(row)
class SpaceStore(UpdatedMixin):
    """Queries on ``Space`` rows, always scoped to one gateway user."""

    model = Space

    def __init__(self, session: Session) -> None:
        # This deliberately does not call super().__init__: instead of
        # resetting ``updated``, every Space row is deleted.
        session.execute(delete(Space))

    @staticmethod
    def add_or_get(session: Session, user_pk: int, legacy_id: str) -> Space:
        """Return the user's space with *legacy_id*, creating it if missing.

        A newly created space is committed immediately.
        """
        found = session.execute(
            select(Space).where(
                Space.user_account_id == user_pk, Space.legacy_id == legacy_id
            )
        ).scalar_one_or_none()
        if found is not None:
            return found
        created = Space(user_account_id=user_pk, legacy_id=legacy_id)
        session.add(created)
        session.commit()
        return created

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterable[Space]:
        """Return every space of the user, as a lazily iterated result."""
        by_user = select(Space).where(Space.user_account_id == user_pk)
        return session.execute(by_user).scalars()

    @staticmethod
    def get_by_legacy_id(
        session: Session, user_pk: int, legacy_id: str, full: bool = False
    ) -> Space | None:
        """Return the user's space with *legacy_id*, or None.

        With *full*, ``owners`` and ``creator`` are eagerly joined.
        """
        stmt = select(Space).where(
            Space.user_account_id == user_pk, Space.legacy_id == legacy_id
        )
        if full:
            stmt = stmt.options(joinedload(Space.owners), joinedload(Space.creator))
        return session.execute(stmt).unique().scalar_one_or_none()

    @staticmethod
    def get_unupdated(session: Session, user_pk: int) -> list[Space]:
        """Return the user's spaces whose ``updated`` flag is False."""
        stale = select(Space).where(
            Space.user_account_id == user_pk, Space.updated.is_(False)
        )
        return list(session.scalars(stale))

    @staticmethod
    def get_rooms(
        session: Session,
        user_pk: int,
        legacy_id: str,
        room_legacy_ids: Iterable[str] = (),
    ) -> list[Room]:
        """Return the user's rooms in space *legacy_id*.

        Only ``jid`` and ``name`` are loaded. A non-empty *room_legacy_ids*
        restricts the result to those rooms.
        """
        stmt = (
            select(Room)
            .join(Room.space)
            .where(Room.user_account_id == user_pk, Space.legacy_id == legacy_id)
            .options(load_only(Room.jid, Room.name))
        )
        if room_legacy_ids:
            stmt = stmt.where(Room.legacy_id.in_(room_legacy_ids))
        return list(session.scalars(stmt))

    @staticmethod
    def exists(session: Session, user_pk: int, legacy_id: str) -> bool:
        """Return True if the user has a space with *legacy_id*."""
        condition = sa.exists().where(
            Space.user_account_id == user_pk, Space.legacy_id == legacy_id
        )
        return session.execute(select(condition)).scalar_one()
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session: Session, flush_context: sa.ExecutionContext) -> None:
    """After every flush, delete avatars left unreferenced by deleted rows.

    Candidates are the avatar IDs of Contact/Room objects deleted in this
    flush. A candidate is removed only if no Contact or Room still
    references it.
    """
    candidates = {
        obj.avatar_id
        for obj in session.deleted
        if isinstance(obj, (Contact, Room)) and obj.avatar_id
    }
    if not candidates:
        return

    used_by_contact = sa.exists().where(Contact.avatar_id == Avatar.id)
    used_by_room = sa.exists().where(Room.avatar_id == Avatar.id)
    outcome = session.execute(
        sa.delete(Avatar).where(
            Avatar.id.in_(candidates),
            ~used_by_contact,
            ~used_by_room,
        )
    )
    removed = outcome.rowcount  # type:ignore[attr-defined]
    log.debug("Auto-deleted %s orphaned avatars", removed)
# Module logger. It is defined after the functions that use it, which works
# because they only look the name up when they run.
log: logging.Logger = logging.getLogger(__name__)