Coverage for slidge/db/store.py: 89%
344 statements
coverage.py v7.13.0, created at 2026-01-06 15:18 +0000

from __future__ import annotations

import hashlib
import logging
import shutil
import uuid
from datetime import datetime, timedelta, timezone
from mimetypes import guess_extension
from typing import Collection, Iterator, Optional, Type

import sqlalchemy as sa
from slixmpp.exceptions import XMPPError
from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
from sqlalchemy import Engine, delete, event, select, update
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, attributes, sessionmaker

from ..core import config
from ..util.archive_msg import HistoryMessage
from ..util.types import MamMetadata, Sticker
from .meta import Base
from .models import (
    ArchivedMessage,
    ArchivedMessageSource,
    Avatar,
    Bob,
    Contact,
    ContactSent,
    DirectMessages,
    DirectThreads,
    GatewayUser,
    GroupMessages,
    GroupMessagesOrigin,
    GroupThreads,
    Participant,
    Room,
)
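

# Mixin for stores whose model rows carry an "updated" flag: it clears the flag
# for every row at start-up and provides a primary-key lookup helper.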
class UpdatedMixin:
    model: Type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        self.reset_updated(session)

    def get_by_pk(self, session: Session, pk: int) -> Type[Base]:
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        return session.scalar(stmt)

    def reset_updated(self, session: Session) -> None:
        session.execute(update(self.model).values(updated=False))
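

# Aggregates the individual stores behind a single SQLAlchemy engine and runs
# their start-up resets inside one committed session.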
class SlidgeStore:
    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker(engine)

        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()
        with self.session() as session:
            self.contacts = ContactStore(session)
            self.mam = MAMStore(session, self.session)
            self.rooms = RoomStore(session)
            self.participants = ParticipantStore(session)
            session.commit()
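

# Persists gateway users. update() flags the legacy_module_data and preferences
# attributes as modified so that in-place mutations are actually written back
# (see the SQLAlchemy discussion linked in the method body).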
class UserStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                attributes.flag_modified(user, "legacy_module_data")
                attributes.flag_modified(user, "preferences")
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()
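

# Thin wrapper holding the session factory used for avatar-related operations.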
class AvatarStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker
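

# Union of the mapping tables that relate legacy IDs to XMPP IDs.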
LegacyToXmppType = (
    Type[DirectMessages]
    | Type[DirectThreads]
    | Type[GroupMessages]
    | Type[GroupThreads]
    | Type[GroupMessagesOrigin]
)
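

# Bidirectional mapping between legacy network IDs and XMPP stanza IDs for
# messages and threads, in direct and group chats. One legacy message ID may
# correspond to several XMPP IDs, hence the list parameter in set_msg().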
class IdMapStore:
    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        kwargs = dict(foreign_key=foreign_key, legacy_id=legacy_id)
        # Materialise the result so the truthiness check below is meaningful:
        # a ScalarResult object is always truthy, even when it yields no rows.
        ids = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            msg = type_(xmpp_id=xmpp_id, **kwargs)
            session.add(msg)

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=legacy_id
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> Optional[str]:
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> Optional[str]:
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )
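

# Stores roster contacts; on start-up, cached presence and caps verification
# strings are invalidated. It also records which message IDs a contact has
# sent, so that pop_sent_up_to() can later return everything up to a given ID.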
class ContactStore(UpdatedMixin):
    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(update(Contact).values(cached_presence=False))
        session.execute(update(Contact).values(caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        if (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        ) is not None:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
        session.add(new)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        result = []
        to_del = []
        for row in session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars():
            to_del.append(row.id)
            result.append(row.msg_id)
            if row.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
        return result
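

# Message Archive Management (XEP-0313) storage: archives stanzas per room,
# answers paged/filtered history queries, and tracks which archived messages
# the user has displayed.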
class MAMStore:
    def __init__(self, session: Session, session_maker) -> None:
        self.session = session_maker
        self.reset_source(session)

    @staticmethod
    def reset_source(session: Session) -> None:
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        session.execute(
            delete(ArchivedMessage).where(
                ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
            )
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: Optional[str],
    ) -> None:
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested message IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> Optional[ArchivedMessage]:
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> Optional[ArchivedMessage]:
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

    @staticmethod
    def pop_unread_up_to(session: Session, room_pk: int, stanza_id: str) -> list[str]:
        q = (
            select(ArchivedMessage.id, ArchivedMessage.stanza_id)
            .where(ArchivedMessage.room_id == room_pk)
            .where(~ArchivedMessage.displayed_by_user)
            .where(ArchivedMessage.legacy_id.is_not(None))
            .order_by(ArchivedMessage.timestamp.asc())
        )

        ref = session.scalar(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        )

        if ref is None:
            log.debug(
                "(pop unread in muc): message not found, returning all MAM messages."
            )
            rows = session.execute(q)
        else:
            rows = session.execute(q.where(ArchivedMessage.timestamp <= ref.timestamp))

        pks: list[int] = []
        stanza_ids: list[str] = []

        for id_, stanza_id in rows:
            pks.append(id_)
            stanza_ids.append(stanza_id)

        session.execute(
            update(ArchivedMessage)
            .where(ArchivedMessage.id.in_(pks))
            .values(displayed_by_user=True)
        )
        return stanza_ids

    @staticmethod
    def is_displayed_by_user(
        session: Session, room_jid: str, legacy_msg_id: str
    ) -> bool:
        return any(
            session.execute(
                select(ArchivedMessage.displayed_by_user)
                .join(Room)
                .where(Room.jid == room_jid)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalars()
        )
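

# Stores MUC rooms; on start-up, transient state (subject setter, user
# resources, history/participants fill flags) is reset alongside the
# "updated" flag inherited from UpdatedMixin.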
class RoomStore(UpdatedMixin):
    model = Room

    def reset_updated(self, session: Session) -> None:
        super().reset_updated(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))

    @staticmethod
    def get(session: Session, user_pk: int, legacy_id: str) -> Room:
        return session.execute(
            select(Room)
            .where(Room.user_account_id == user_pk)
            .where(Room.legacy_id == legacy_id)
        ).scalar_one()

    @staticmethod
    def nick_available(session: Session, room_pk: int, nickname: str) -> bool:
        return (
            session.execute(
                select(Participant.id).filter_by(room_id=room_pk, nickname=nickname)
            )
        ).one_or_none() is None
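

# Room participants are purely runtime state: the table is wiped when the
# store is created.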
class ParticipantStore:
    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        query = select(Participant).where(Participant.room_id == room_pk)
        if not user_included:
            query = query.where(~Participant.is_user)
        yield from session.scalars(query).unique()
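

# Bits of Binary (XEP-0231) storage: payloads are written as files under
# HOME_DIR/bob_store, while their SHA-1/256/512 digests and content type are
# kept in the database so a CID of the form "<alg>+<digest>@bob.xmpp.org" can
# be resolved back to the stored file.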
class BobStore:
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, session: Session, cid: str) -> Bob | None:
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid, _node, _ifrom, cid: str
    ) -> BitsOfBinary | None:
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        (self.root_dir / file_name).unlink()

    def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, bob["cid"])
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            path.unlink(missing_ok=True)
            raise ValueError("Provided CID does not match calculated hash")
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)
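

# Session-level hook: whenever contacts or rooms are deleted, garbage-collect
# avatars that are no longer referenced by any remaining contact or room.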
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session, flush_context):
    if not session.deleted:
        return

    potentially_orphaned = set()
    for obj in session.deleted:
        if isinstance(obj, (Contact, Room)) and obj.avatar_id:
            potentially_orphaned.add(obj.avatar_id)
    if not potentially_orphaned:
        return

    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(potentially_orphaned),
                sa.not_(sa.exists().where(Contact.avatar_id == Avatar.id)),
                sa.not_(sa.exists().where(Room.avatar_id == Avatar.id)),
            )
        )
    )
    deleted_count = result.rowcount
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)


log = logging.getLogger(__name__)
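

# Minimal usage sketch (not part of the module): wiring the store to an
# in-memory SQLite engine. The create_engine URL and the explicit
# Base.metadata.create_all() call are illustrative assumptions; slidge
# normally configures the engine and schema elsewhere.
#
#     from sqlalchemy import create_engine
#     from slidge.db.meta import Base
#     from slidge.db.store import SlidgeStore
#
#     engine = create_engine("sqlite+pysqlite:///:memory:")
#     Base.metadata.create_all(engine)  # create the tables declared in models
#     store = SlidgeStore(engine)
#     with store.session() as session:
#         # empty archive -> no MAM metadata yet
#         print(store.mam.get_first_and_last(session, room_pk=1))  # []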