Coverage for slidge/db/store.py: 88%
315 statements
from __future__ import annotations

import hashlib
import logging
import shutil
import uuid
from datetime import datetime, timedelta, timezone
from mimetypes import guess_extension
from typing import Collection, Iterator, Optional, Type

import sqlalchemy as sa
from slixmpp.exceptions import XMPPError
from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
from sqlalchemy import Engine, delete, event, select, update
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, attributes, sessionmaker

from ..core import config
from ..util.archive_msg import HistoryMessage
from ..util.types import MamMetadata, Sticker
from .meta import Base
from .models import (
    ArchivedMessage,
    ArchivedMessageSource,
    Avatar,
    Bob,
    Contact,
    ContactSent,
    DirectMessages,
    DirectThreads,
    GatewayUser,
    GroupMessages,
    GroupMessagesOrigin,
    GroupThreads,
    Participant,
    Room,
)
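

# Mixin for stores whose model has an "updated" boolean column: constructing the
# store resets the flag on every row, presumably so rows refreshed during the
# current run can be told apart from stale ones left over from a previous run.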
class UpdatedMixin:
    model: Type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        session.execute(update(self.model).values(updated=False))

    def get_by_pk(self, session: Session, pk: int) -> Type[Base]:
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        return session.scalar(stmt)
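

# Aggregate entry point: builds one sub-store per table and runs their startup
# resets inside a single committed session.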
class SlidgeStore:
    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker(engine)

        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()
        with self.session() as session:
            self.contacts = ContactStore(session)
            self.mam = MAMStore(session, self.session)
            self.rooms = RoomStore(session)
            self.participants = ParticipantStore(session)
            session.commit()
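

# attributes.flag_modified marks the user's legacy_module_data and preferences
# attributes as dirty so SQLAlchemy persists them on commit even when they were
# mutated in place (in-place mutations of mutable column values are not tracked
# automatically); see the linked discussion inside update() below.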
class UserStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                attributes.flag_modified(user, "legacy_module_data")
                attributes.flag_modified(user, "preferences")
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()


class AvatarStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker


LegacyToXmppType = (
    Type[DirectMessages]
    | Type[DirectThreads]
    | Type[GroupMessages]
    | Type[GroupThreads]
    | Type[GroupMessagesOrigin]
)
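

# Bidirectional mapping between legacy network IDs and XMPP stanza IDs, for
# messages and threads, in both direct chats and groups. One legacy ID may map
# to several XMPP IDs, which is why the setters take a list and get_xmpp
# returns one.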
class IdMapStore:
    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        kwargs = dict(foreign_key=foreign_key, legacy_id=legacy_id)
        ids = session.scalars(
            select(type_.id).filter(
                type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            msg = type_(xmpp_id=xmpp_id, **kwargs)
            session.add(msg)

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    def set_origin(
        self, session: Session, foreign_key: int, legacy_id: str, xmpp_id: str
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupMessagesOrigin,
        )

    def get_origin(
        self, session: Session, foreign_key: int, legacy_id: str
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessagesOrigin,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=legacy_id
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> Optional[str]:
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self,
        session: Session,
        foreign_key: int,
        xmpp_id: str,
        group: bool,
        origin: bool = False,
    ) -> Optional[str]:
        if origin and group:
            return self._get_legacy(
                session,
                foreign_key,
                xmpp_id,
                GroupMessagesOrigin,
            )
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )
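

# A minimal usage sketch for IdMapStore (hypothetical IDs, assuming an open
# session and a contact whose primary key is 1):
#
#   store.id_map.set_msg(session, 1, "legacy-42", ["uuid-a", "uuid-b"], group=False)
#   store.id_map.get_xmpp(session, 1, "legacy-42", group=False)   # ["uuid-a", "uuid-b"]
#   store.id_map.get_legacy(session, 1, "uuid-a", group=False)    # "legacy-42"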


class ContactStore(UpdatedMixin):
    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(update(Contact).values(cached_presence=False))
        session.execute(update(Contact).values(caps_ver=None))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        if (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        ) is not None:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
        session.add(new)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        result = []
        to_del = []
        for row in session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars():
            to_del.append(row.id)
            result.append(row.msg_id)
            if row.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
        return result
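

# Per-room message archive backing MAM (XEP-0313) queries. On startup every
# stored row is re-tagged as BACKFILL, so only messages archived during the
# current run are marked LIVE.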
class MAMStore:
    def __init__(self, session: Session, session_maker) -> None:
        self.session = session_maker
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        session.execute(
            delete(ArchivedMessage).where(
                ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
            )
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: Optional[str],
    ) -> None:
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)
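
    # The filtering below mirrors MAM/RSM paging semantics: before_id/after_id
    # anchor the page on an existing stanza's timestamp, flip reverses the
    # ordering, and last_page_n keeps only the last page of the (possibly
    # flipped) result set.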
    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested message IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> Optional[ArchivedMessage]:
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> Optional[ArchivedMessage]:
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )
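

# Room rows survive restarts, but transient state (subject setter, the user's
# joined resources, the history/participants fill markers) is reset whenever
# the store is constructed.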
class RoomStore(UpdatedMixin):
    model = Room

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))


class ParticipantStore:
    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session: Session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        query = select(Participant).where(Participant.room_id == room_pk)
        if not user_included:
            query = query.where(~Participant.is_user)
        yield from session.scalars(query).unique()
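

# Bits of Binary (XEP-0231) payloads: file contents live on disk under
# HOME_DIR / "bob_store" (migrated from the older "slidge_stickers" directory),
# while the database row keeps the file name, content type and SHA-1/256/512
# digests used to resolve CIDs of the form "<alg>+<digest>@bob.xmpp.org".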
class BobStore:
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        if (config.HOME_DIR / "slidge_stickers").exists():
            shutil.move(
                config.HOME_DIR / "slidge_stickers", config.HOME_DIR / "bob_store"
            )
        self.root_dir = config.HOME_DIR / "bob_store"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, session: Session, cid: str) -> Bob | None:
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid, _node, _ifrom, cid: str
    ) -> BitsOfBinary | None:
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        (self.root_dir / file_name).unlink()

    def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, bob["cid"])
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            path.unlink(missing_ok=True)
            raise ValueError("Provided CID does not match calculated hash")
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)
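

# Session-level hook: after a flush that deleted contacts or rooms, drop any
# avatar rows that are no longer referenced by either table.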
@event.listens_for(sa.orm.Session, "after_flush")
def _check_avatar_orphans(session, flush_context):
    if not session.deleted:
        return

    potentially_orphaned = set()
    for obj in session.deleted:
        if isinstance(obj, (Contact, Room)) and obj.avatar_id:
            potentially_orphaned.add(obj.avatar_id)
    if not potentially_orphaned:
        return

    result = session.execute(
        sa.delete(Avatar).where(
            sa.and_(
                Avatar.id.in_(potentially_orphaned),
                sa.not_(sa.exists().where(Contact.avatar_id == Avatar.id)),
                sa.not_(sa.exists().where(Room.avatar_id == Avatar.id)),
            )
        )
    )
    deleted_count = result.rowcount
    log.debug("Auto-deleted %s orphaned avatars", deleted_count)


log = logging.getLogger(__name__)