Coverage for slidge / db / store.py: 89%
345 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1from __future__ import annotations
3import hashlib
4import logging
5import shutil
6import uuid
7from collections.abc import Callable, Collection, Iterator
8from datetime import UTC, datetime, timedelta
9from mimetypes import guess_extension
10from typing import Any, ClassVar
12import sqlalchemy as sa
13from slixmpp.exceptions import XMPPError
14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15from sqlalchemy import ColumnElement, Engine, delete, event, select, update
16from sqlalchemy.exc import InvalidRequestError
17from sqlalchemy.orm import Session, attributes, sessionmaker
19from ..core import config
20from ..util.archive_msg import HistoryMessage
21from ..util.types import MamMetadata, Sticker
22from .meta import Base
23from .models import (
24 ArchivedMessage,
25 ArchivedMessageSource,
26 Avatar,
27 Bob,
28 Contact,
29 ContactSent,
30 DirectMessages,
31 DirectThreads,
32 GatewayUser,
33 GroupMessages,
34 GroupMessagesOrigin,
35 GroupThreads,
36 Participant,
37 Room,
38)
class UpdatedMixin:
    """Base for stores whose mapped model carries an ``updated`` column.

    Subclasses set :attr:`model`. Building the store clears ``updated`` on
    every row of that model's table.
    """

    # Mapped model class handled by the concrete store (set by subclasses).
    model: type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> type[Base]:
        """Fetch the row of :attr:`model` whose ``id`` equals *pk*.

        ``session.scalar`` yields ``None`` when no row matches.
        NOTE(review): the value returned is a model instance (or ``None``),
        not a class; the ``type[Base]`` annotation looks inaccurate — confirm
        with callers before tightening it.
        """
        by_id = select(self.model).where(self.model.id == pk)  # type:ignore[attr-defined]
        return session.scalar(by_id)  # type:ignore[no-any-return]

    def reset_updated(self, session: Session) -> None:
        """Set ``updated = False`` on every row of :attr:`model`."""
        clear_flag = update(self.model).values(updated=False)
        session.execute(clear_flag)
class SlidgeStore:
    """Facade bundling every per-table store around a single engine."""

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker[Any](engine)

        # Stores that only keep the session factory (or nothing) for later use.
        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        # Stores whose constructors execute startup statements; they share a
        # single session, committed once all of them have been built.
        with self.session() as startup_session:
            self.contacts = ContactStore(startup_session)
            self.mam = MAMStore(startup_session, self.session)
            self.rooms = RoomStore(startup_session)
            self.participants = ParticipantStore(startup_session)
            startup_session.commit()
class UserStore:
    """Persistence helper for :class:`GatewayUser` rows."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Add *user* to a fresh session and commit it.

        ``legacy_module_data`` and ``preferences`` are flagged as modified
        first, so they are written even when SQLAlchemy did not register an
        in-place change on them.
        """
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                for column in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, column)
            except InvalidRequestError:
                # Best effort: skip flagging when SQLAlchemy refuses it
                # (presumably the attribute is not loaded on *user*).
                pass
            session.add(user)
            session.commit()
class AvatarStore:
    """Holds the session factory used for avatar persistence."""

    def __init__(self, session_maker: sessionmaker[Any]) -> None:
        # Only the factory is kept; nothing is executed at construction time.
        self.session = session_maker
# Mapping tables between legacy IDs and XMPP IDs, as used by IdMapStore
# (each exposes ``id``, ``foreign_key``, ``legacy_id`` and ``xmpp_id``).
# The Direct* tables are used when ``group`` is false, the Group* ones otherwise.
LegacyToXmppType = (
    type[DirectMessages] | type[DirectThreads]
    | type[GroupMessages] | type[GroupThreads] | type[GroupMessagesOrigin]
)
class IdMapStore:
    """Mapping between legacy message/thread IDs and XMPP IDs.

    Every mapping row is scoped by a caller-supplied ``foreign_key`` and lives
    in one of the :data:`LegacyToXmppType` tables, chosen by the ``group`` /
    ``origin`` flags of the public methods (Group* tables when ``group`` is
    true, Direct* tables otherwise).
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace every mapping of *legacy_id* in *type_* with *xmpp_ids*."""
        # ``session.scalars`` returns a ScalarResult object, which is always
        # truthy; materialize it so the emptiness check below is meaningful
        # and the DELETE only runs when there is something to reset.
        existing_pks = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if existing_pks:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(existing_pks)))
        for xmpp_id in xmpp_ids:
            session.add(
                type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map legacy thread *legacy_id* to the single XMPP thread *xmpp_id*."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map legacy message *legacy_id* to one or more XMPP message IDs."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        """Record *xmpp_id* for *legacy_id* in the group origin-ID table."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        """Return the XMPP IDs stored for *legacy_id* in the origin-ID table."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return every XMPP ID mapped to *legacy_id* (stringified) in *type_*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=str(legacy_id)
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to legacy message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> str | None:
        """Return the legacy ID mapped to *xmpp_id* in *type_*, or ``None``."""
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> str | None:
        """Return the legacy message ID mapped to *xmpp_id*, or ``None``.

        *origin* is only honoured for groups: it looks the ID up in the
        origin-ID table instead of the regular message table.
        """
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> str | None:
        """Return the legacy thread ID mapped to XMPP thread *xmpp_id*, or ``None``."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Tell whether *legacy_id* has any mapping in the message table."""
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )
class ContactStore(UpdatedMixin):
    """Contact persistence, plus a per-contact queue of sent message IDs."""

    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Cached presence flags and caps versions are discarded at construction.
        reset_presence = update(Contact).values(cached_presence=False)
        reset_caps = update(Contact).values(caps_ver=None)
        session.execute(reset_presence)
        session.execute(reset_caps)

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Queue *msg_id* for contact *contact_pk*, unless it is already queued."""
        already_queued = (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk, ContactSent.msg_id == msg_id)
            .first()
            is not None
        )
        if already_queued:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return the queued IDs of *contact_pk* up to *msg_id*.

        Rows are walked in ``ContactSent.id`` order and *msg_id* itself is
        included. If *msg_id* is not queued, the whole queue is popped.
        """
        popped_ids: list[str] = []
        row_pks: list[int] = []
        queued = session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars()
        for sent in queued:
            row_pks.append(sent.id)
            popped_ids.append(sent.msg_id)
            if sent.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(row_pks)))
        return popped_ids
class MAMStore:
    """Per-room message archive (MAM), stored as :class:`ArchivedMessage` rows."""

    def __init__(self, session: Session, session_maker: sessionmaker[Any]) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        """Mark every archived message as ``ArchivedMessageSource.BACKFILL``."""
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete archived messages whose timestamp is more than *days* days old."""
        # Stored timestamps are treated as UTC when read back (see
        # ``get_messages``), so the cutoff is computed in UTC too. A naive
        # ``datetime.now()`` would be local time and shift the cutoff by the
        # host's UTC offset.
        cutoff = datetime.now(UTC) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Insert *message* into room *room_pk*'s archive, or update it in place.

        An existing row is matched by stanza ID first, then (when given) by
        *legacy_msg_id*. *archive_only* selects BACKFILL instead of LIVE as
        the stored source.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == str(legacy_msg_id))
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of archived message *stanza_id* in room *room_pk*.

        :raises XMPPError: ``item-not-found`` if no such message is archived.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp).where(
                ArchivedMessage.stanza_id == stanza_id,
                ArchivedMessage.room_id == room_pk,
            )
        ).scalar_one_or_none()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        before_id: str | None = None,
        after_id: str | None = None,
        ids: Collection[str] = (),
        last_page_n: int | None = None,
        sender: str | None = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of room *room_pk* matching every given filter.

        Messages are ordered by timestamp, ascending unless *flip* is set.
        *last_page_n* keeps the last *last_page_n* messages of that ordering,
        or the first ones when *flip* is set.

        This is a generator: the ``XMPPError`` (``item-not-found``) raised
        for an unknown *before_id* / *after_id*, or when some of *ids* are
        missing, surfaces on the first iteration.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = MAMStore._timestamp_of(session, room_pk, before_id)
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = MAMStore._timestamp_of(session, room_pk, after_id)
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        # Compare against the distinct requested IDs: duplicates in *ids*
        # would otherwise trigger a spurious item-not-found.
        if ids and len(msgs) != len(set(ids)):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                # ``msgs[-0:]`` is the whole list, so slice from an explicit
                # start index to make ``last_page_n == 0`` yield nothing.
                msgs = msgs[max(len(msgs) - last_page_n, 0) :]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=UTC)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> ArchivedMessage | None:
        """Return the oldest archived message of the room (optionally with a legacy ID)."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally from *source* only."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata for the oldest and newest messages (empty if none)."""
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message having a legacy ID, optionally from *source* only."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> ArchivedMessage | None:
        """Return the oldest *source* message with a legacy ID newer than *after_id*.

        *after_id* is a legacy ID; ``None`` is returned when it is unknown.
        """
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        if after_timestamp is None:
            # Unknown reference message: comparing against NULL could never
            # match anything, so skip the query.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> ArchivedMessage | None:
        """Return the first archived message of the room with *legacy_id*, if any."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        """Mark not-yet-displayed messages as displayed and return their stanza IDs.

        Only messages with a legacy ID are considered, up to (by timestamp,
        inclusive) the message *stanza_id*. If *stanza_id* is not archived,
        every such unread message of the room is marked.
        """
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        for id_, row_stanza_id in rows:
            pks.append(id_)
            stanza_ids.append(row_stanza_id)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        """Tell whether any message *legacy_msg_id* of room *room_jid* was displayed."""
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )
class RoomStore(UpdatedMixin):
    """Room (group chat) persistence."""

    model = Room

    def reset_updated(self, session: Session) -> None:
        """Clear ``updated`` plus the per-run room state on every room."""
        super().reset_updated(session)
        transient_reset = update(Room).values(
            subject_setter=None,
            user_resources=None,
            history_filled=False,
            participants_filled=False,
        )
        session.execute(transient_reset)

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room whose ``user_account_id`` is *user_pk*."""
        owned = select(Room).where(Room.user_account_id == user_pk)
        yield from session.scalars(owned)

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        """Return the room of *user_pk* with *legacy_id*.

        Propagates ``scalar_one``'s error when zero or several rows match.
        """
        by_legacy_id = select(Room).where(
            Room.user_account_id == user_pk, Room.legacy_id == legacy_id
        )
        return session.execute(by_legacy_id).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        """Tell whether no participant of room *room_pk* uses *nickname*."""
        holder = session.execute(
            select(Participant.id).filter_by(room_id=room_pk, nickname=nickname)
        ).one_or_none()
        return holder is None
class ParticipantStore:
    """Participant persistence."""

    def __init__(self, session: Session) -> None:
        # Every participant row is deleted when the store is built.
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*.

        When *user_included* is false, rows with ``is_user`` set are skipped.
        """
        criteria = [Participant.room_id == room_pk]
        if not user_included:
            criteria.append(~Participant.is_user)
        yield from session.scalars(select(Participant).where(*criteria)).unique()
class BobStore:
    """XEP-0231 Bits of Binary storage.

    Payloads are files under :attr:`root_dir`; their metadata (file name,
    content type and digests) is stored in :class:`Bob` rows.
    """

    # Hash algorithm name as found in a CID -> Bob column holding that digest.
    _ATTR_MAP: ClassVar[dict[str, str]] = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob digest column -> hashlib constructor used to compute it.
    _ALG_MAP: ClassVar[dict[str, Callable[[bytes], hashlib._Hash]]] = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        # An older directory name is migrated by moving it to the new location.
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split ``alg+digest[@bob.xmpp.org]`` into its ``+``-separated parts."""
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str) -> ColumnElement[bool]:
        """Build the WHERE clause matching *cid*'s digest column.

        :raises ValueError: if *cid* is not ``alg+digest`` or *alg* is unknown.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            raise ValueError
        return getattr(Bob, attr) == digest  # type:ignore[no-any-return]

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the :class:`Bob` row for *cid*, or ``None`` (also on a bad CID)."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None
        return session.query(Bob).filter(condition).scalar()  # type:ignore[no-any-return]

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return a :class:`Sticker` (path, content type, digests) for *cid*."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> BitsOfBinary | None:
        """Build a :class:`BitsOfBinary` stanza for *cid*, or ``None`` if unknown.

        The unused ``_jid``/``_node``/``_ifrom`` parameters presumably match
        slixmpp's BoB handler signature — confirm against the registration.
        """
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(
        self, session: Session, _jid: object, _node: object, _ifrom: object, cid: str
    ) -> None:
        """Delete the :class:`Bob` row for *cid*, then its file on disk."""
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        try:
            (self.root_dir / file_name).unlink()
        except FileNotFoundError:
            # The row is already deleted; a missing file must not make the
            # whole operation fail.
            log.warning("BoB file %s was already missing", file_name)

    def set_bob(
        self,
        session: Session,
        _jid: object,
        _node: object,
        _ifrom: object,
        bob: BitsOfBinary,
    ) -> None:
        """Store *bob*'s payload on disk and add a :class:`Bob` row for it.

        Invalid CIDs and unknown algorithms are logged and ignored; an already
        stored CID is skipped.

        :raises ValueError: if the CID digest does not match the payload.
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, cid)
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        # Verify the digest before touching the filesystem, so a mismatching
        # payload never leaves a file behind.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        content_type = bob["type"]
        path = self.root_dir / uuid.uuid4().hex
        if content_type:
            path = path.with_suffix(guess_extension(content_type) or "")
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=content_type or None, **hashes)
        session.add(row)
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(
    session: Session, flush_context: sa.orm.UOWTransaction
) -> None:
    """Delete avatars left unreferenced by contacts/rooms removed in a flush.

    Registered as an ``after_flush`` listener on all ORM sessions. At that
    point ``session.deleted`` still lists the objects the flush removed, so
    their avatar IDs are the only candidates checked.
    """
    candidate_ids = {
        obj.avatar_id
        for obj in session.deleted
        if isinstance(obj, (Contact, Room)) and obj.avatar_id
    }
    if not candidate_ids:
        return

    # Only delete candidates that no remaining contact or room points to.
    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(candidate_ids),
                sa.not_(sa.exists().where(Contact.avatar_id == Avatar.id)),
                sa.not_(sa.exists().where(Room.avatar_id == Avatar.id)),
            )
        )
    )
    deleted_count = result.rowcount  # type:ignore[attr-defined]
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)
# Module logger. It is defined at the bottom of the module; the functions
# above only look it up when called, so this placement works.
log = logging.getLogger(name=__name__)