Coverage for slidge/db/store.py: 87% (291 statements)
coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
from __future__ import annotations

import hashlib
import logging
import uuid
from datetime import datetime, timedelta, timezone
from mimetypes import guess_extension
from typing import Collection, Iterator, Optional, Type

from slixmpp.exceptions import XMPPError
from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
from sqlalchemy import Engine, delete, select, update
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session, attributes, sessionmaker

from ..core import config
from ..util.archive_msg import HistoryMessage
from ..util.types import MamMetadata, Sticker
from .meta import Base
from .models import (
    ArchivedMessage,
    ArchivedMessageSource,
    Bob,
    Contact,
    ContactSent,
    DirectMessages,
    DirectThreads,
    GatewayUser,
    GroupMessages,
    GroupThreads,
    Participant,
    Room,
)
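

# Mixin for stores whose model rows carry an "updated" flag: the flag is
# cleared on every row at startup, and get_by_pk() provides primary-key lookup.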
class UpdatedMixin:
    model: Type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        session.execute(update(self.model).values(updated=False))

    def get_by_pk(self, session: Session, pk: int) -> Type[Base]:
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        return session.scalar(stmt)
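

# Entry point bundling the individual stores behind a single engine and
# session factory.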
class SlidgeStore:
    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        self.session = sessionmaker(engine)

        self.users = UserStore(self.session)
        self.avatars = AvatarStore(self.session)
        self.id_map = IdMapStore()
        self.bob = BobStore()
        with self.session() as session:
            self.contacts = ContactStore(session)
            self.mam = MAMStore(session, self.session)
            self.rooms = RoomStore(session)
            self.participants = ParticipantStore(session)
            session.commit()


class UserStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                attributes.flag_modified(user, "legacy_module_data")
                attributes.flag_modified(user, "preferences")
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()


class AvatarStore:
    def __init__(self, session_maker) -> None:
        self.session = session_maker


LegacyToXmppType = (
    Type[DirectMessages]
    | Type[DirectThreads]
    | Type[GroupMessages]
    | Type[GroupThreads]
)
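

# Bidirectional mapping between legacy message/thread IDs and XMPP stanza IDs,
# for both direct and group chats.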
class IdMapStore:
    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        kwargs = dict(foreign_key=foreign_key, legacy_id=legacy_id)
        ids = session.scalars(
            select(type_.id).filter(
                type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            msg = type_(xmpp_id=xmpp_id, **kwargs)
            session.add(msg)

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=legacy_id
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> Optional[str]:
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )
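

# Contact rows: the cached presence is invalidated at startup, and ContactSent
# rows keep track of which message IDs each contact has sent.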
class ContactStore(UpdatedMixin):
    model = Contact

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(update(Contact).values(cached_presence=False))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        if (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk)
            .where(ContactSent.msg_id == msg_id)
            .first()
        ) is not None:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)
            return
        new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
        session.add(new)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        result = []
        to_del = []
        for row in session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars():
            to_del.append(row.id)
            result.append(row.msg_id)
            if row.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
        return result
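

# Message Archive Management (MAM) storage, keyed by room; at startup every
# archived message is re-flagged as BACKFILL so that new live messages can be
# told apart from history.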
class MAMStore:
    def __init__(self, session: Session, session_maker) -> None:
        self.session = session_maker
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        session.execute(
            delete(ArchivedMessage).where(
                ArchivedMessage.timestamp < datetime.now() - timedelta(days=days)
            )
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: Optional[str],
    ) -> None:
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested message IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            if flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> Optional[ArchivedMessage]:
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> Optional[ArchivedMessage]:
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )
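

# Room rows; volatile per-session state (subject setter, user resources, fill
# flags) is reset at startup.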
class RoomStore(UpdatedMixin):
    model = Room

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        session.execute(
            update(Room).values(
                subject_setter=None,
                user_resources=None,
                history_filled=False,
                participants_filled=False,
            )
        )

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        yield from session.scalars(select(Room).where(Room.user_account_id == user_pk))


class ParticipantStore:
    def __init__(self, session: Session) -> None:
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        query = select(Participant).where(Participant.room_id == room_pk)
        if not user_included:
            query = query.where(~Participant.is_user)
        yield from session.scalars(query).unique()
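

# Bits of Binary (XEP-0231) storage, used for stickers: payloads are written
# under HOME_DIR / "slidge_stickers" and indexed by their sha-1/256/512 digests.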
class BobStore:
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        self.root_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, session: Session, cid: str) -> Bob | None:
        try:
            return session.query(Bob).filter(self.__get_condition(cid)).scalar()
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid, _node, _ifrom, cid: str
    ) -> BitsOfBinary | None:
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
        try:
            file_name = session.scalar(
                delete(Bob).where(self.__get_condition(cid)).returning(Bob.file_name)
            )
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        (self.root_dir / file_name).unlink()

    def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, bob["cid"])
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            path.unlink(missing_ok=True)
            raise ValueError("Provided CID does not match calculated hash")
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)


log = logging.getLogger(__name__)