Coverage for slidge/db/store.py: 87%
666 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1from __future__ import annotations
3import hashlib
4import json
5import logging
6import uuid
7from contextlib import contextmanager
8from datetime import datetime, timedelta, timezone
9from mimetypes import guess_extension
10from typing import TYPE_CHECKING, Collection, Iterator, Optional, Type
12from slixmpp import JID, Iq, Message, Presence
13from slixmpp.exceptions import XMPPError
14from slixmpp.plugins.xep_0231.stanza import BitsOfBinary
15from sqlalchemy import Engine, delete, select, update
16from sqlalchemy.orm import Session, attributes, load_only
17from sqlalchemy.sql.functions import count
19from ..core import config
20from ..util.archive_msg import HistoryMessage
21from ..util.types import URL, CachedPresence, ClientType
22from ..util.types import Hat as HatTuple
23from ..util.types import MamMetadata, MucAffiliation, MucRole, Sticker
24from .meta import Base
25from .models import (
26 ArchivedMessage,
27 ArchivedMessageSource,
28 Attachment,
29 Avatar,
30 Bob,
31 Contact,
32 ContactSent,
33 GatewayUser,
34 Hat,
35 LegacyIdsMulti,
36 Participant,
37 Room,
38 XmppIdsMulti,
39 XmppToLegacyEnum,
40 XmppToLegacyIds,
41 participant_hats,
42)
44if TYPE_CHECKING:
45 from ..contact.contact import LegacyContact
46 from ..group.participant import LegacyParticipant
47 from ..group.room import LegacyMUC
class EngineMixin:
    """Base class giving a store access to a SQLAlchemy engine.

    Nested calls to :meth:`session` (from any store) share one session through
    the module-level ``_session`` variable. Store methods that call each other
    therefore operate inside the same session.
    """

    def __init__(self, engine: Engine):
        self._engine = engine

    @contextmanager
    def session(self, **session_kwargs) -> Iterator[Session]:
        """Yield the shared session, opening a new one if none is active.

        ``session_kwargs`` are forwarded to :class:`Session` only when a new
        session is opened; they are ignored when an outer session is reused.
        Only the outermost caller closes the session and clears ``_session``.
        """
        global _session
        if _session is None:
            with Session(self._engine, **session_kwargs) as fresh:
                _session = fresh
                try:
                    yield fresh
                finally:
                    # Let the next outermost caller open a new session.
                    _session = None
        else:
            yield _session
class UpdatedMixin(EngineMixin):
    """Store mixin for models that carry an ``updated`` boolean column.

    Constructing the store resets ``updated`` to ``False`` on every row of
    :attr:`model`.
    """

    model: Type[Base] = NotImplemented

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        with self.session() as session:
            reset_flags = update(self.model).values(updated=False)  # type:ignore
            session.execute(reset_flags)
            session.commit()

    def get_by_pk(self, pk: int) -> Optional[Base]:
        """Return the :attr:`model` row whose ``id`` equals ``pk``, or None."""
        with self.session() as session:
            by_id = select(self.model).where(self.model.id == pk)  # type:ignore
            return session.scalar(by_id)
class SlidgeStore(EngineMixin):
    """Aggregate holding one instance of every table-specific store.

    All sub-stores are bound to the same engine.
    """

    def __init__(self, engine: Engine):
        super().__init__(engine)
        shared = self._engine
        # Keep this construction order: several of these constructors reset
        # or purge their tables (contacts, MAM, rooms, participants).
        self.users = UserStore(shared)
        self.avatars = AvatarStore(shared)
        self.contacts = ContactStore(shared)
        self.mam = MAMStore(shared)
        self.multi = MultiStore(shared)
        self.attachments = AttachmentStore(shared)
        self.rooms = RoomStore(shared)
        self.sent = SentStore(shared)
        self.participants = ParticipantStore(shared)
        self.bob = BobStore(shared)
class UserStore(EngineMixin):
    """CRUD helpers for :class:`GatewayUser` rows, keyed by bare JID."""

    def new(self, jid: JID, legacy_module_data: dict) -> GatewayUser:
        """Return the user for ``jid``, creating the row if it does not exist.

        Any resource on ``jid`` is stripped first. If the user already exists,
        it is returned unchanged and ``legacy_module_data`` is not applied.

        ``expire_on_commit=False`` is requested so the returned object's
        attributes stay loaded after the commit. The option has no effect when
        an outer session is being reused.
        """
        if jid.resource:
            jid = JID(jid.bare)
        with self.session(expire_on_commit=False) as session:
            user = session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid)
            ).scalar()
            if user is not None:
                return user
            user = GatewayUser(jid=jid, legacy_module_data=legacy_module_data)
            session.add(user)
            session.commit()
            return user

    def update(self, user: GatewayUser):
        """Persist ``user``, including in-place changes to its data columns.

        ``legacy_module_data`` and ``preferences`` are explicitly flagged as
        modified, because in-place mutation of such values is not detected by
        the ORM on its own (see the linked discussion).
        """
        # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
        attributes.flag_modified(user, "legacy_module_data")
        attributes.flag_modified(user, "preferences")
        with self.session() as session:
            session.add(user)
            session.commit()

    def get_all(self) -> Iterator[GatewayUser]:
        """Yield every gateway user; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(select(GatewayUser)).scalars()

    def get(self, jid: JID) -> Optional[GatewayUser]:
        """Return the user matching the bare part of ``jid``, or None."""
        with self.session() as session:
            return session.execute(
                select(GatewayUser).where(GatewayUser.jid == jid.bare)
            ).scalar()

    def get_by_stanza(self, stanza: Iq | Message | Presence) -> Optional[GatewayUser]:
        """Return the user who sent ``stanza`` (by its ``from``), or None."""
        return self.get(stanza.get_from())

    def delete(self, jid: JID) -> None:
        """Delete the user matching the bare part of ``jid``.

        Does nothing if no such user exists. Previously an unknown JID passed
        ``None`` to ``Session.delete``, which raises an SQLAlchemy error.
        """
        with self.session() as session:
            # self.get() reuses this same (shared) session.
            user = self.get(jid)
            if user is None:
                return
            session.delete(user)
            session.commit()

    def set_avatar_hash(self, pk: int, h: str | None = None) -> None:
        """Store ``h`` as the avatar hash of user ``pk`` (None clears it)."""
        with self.session() as session:
            session.execute(
                update(GatewayUser).where(GatewayUser.id == pk).values(avatar_hash=h)
            )
            session.commit()
class AvatarStore(EngineMixin):
    """Lookup, listing and deletion of :class:`Avatar` rows."""

    def get_by_url(self, url: URL | str) -> Optional[Avatar]:
        """Return the avatar whose ``url`` column equals ``url``, or None."""
        by_url = select(Avatar).where(Avatar.url == url)
        with self.session() as session:
            return session.scalar(by_url)

    def get_by_pk(self, pk: int) -> Optional[Avatar]:
        """Return the avatar with primary key ``pk``, or None."""
        by_id = select(Avatar).where(Avatar.id == pk)
        with self.session() as session:
            return session.scalar(by_id)

    def delete_by_pk(self, pk: int):
        """Delete the avatar with primary key ``pk`` (no-op if absent)."""
        with self.session() as session:
            session.execute(delete(Avatar).where(Avatar.id == pk))
            session.commit()

    def get_all(self) -> Iterator[Avatar]:
        """Yield every stored avatar; the session stays open while iterating."""
        with self.session() as session:
            yield from session.scalars(select(Avatar))
class SentStore(EngineMixin):
    """Legacy-ID <-> XMPP-ID mapping for messages sent by each user.

    Every row is tagged with an :class:`XmppToLegacyEnum` kind (DM,
    GROUP_CHAT or THREAD). Lookups only match rows of the requested kind,
    except :meth:`was_sent_by_user`, which ignores the kind.
    """

    def _lookup(self, wanted, user_pk: int, key_column, key: str, kind):
        """Return column ``wanted`` of a row of ``kind`` where ``key_column == key``."""
        stmt = select(wanted).where(
            XmppToLegacyIds.user_account_id == user_pk,
            key_column == key,
            XmppToLegacyIds.type == kind,
        )
        with self.session() as session:
            return session.scalar(stmt)

    def _insert(self, user_pk: int, legacy_id: str, xmpp_id: str, kind) -> None:
        """Unconditionally insert a new mapping row of ``kind``."""
        with self.session() as session:
            session.add(
                XmppToLegacyIds(
                    user_account_id=user_pk,
                    legacy_id=legacy_id,
                    xmpp_id=xmpp_id,
                    type=kind,
                )
            )
            session.commit()

    def set_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Record a DM mapping, reusing a row with the same two IDs if present.

        A reused row is re-tagged as DM whatever its previous kind was.
        """
        with self.session() as session:
            # Legacy Query.scalar() is kept on purpose: unlike Result.scalar(),
            # it raises if several rows match.
            row = (
                session.query(XmppToLegacyIds)
                .filter(
                    XmppToLegacyIds.user_account_id == user_pk,
                    XmppToLegacyIds.legacy_id == legacy_id,
                    XmppToLegacyIds.xmpp_id == xmpp_id,
                )
                .scalar()
            )
            if row is not None:
                log.debug("Resetting a DM from sent store")
            else:
                row = XmppToLegacyIds(user_account_id=user_pk)
            row.legacy_id = legacy_id
            row.xmpp_id = xmpp_id
            row.type = XmppToLegacyEnum.DM
            session.add(row)
            session.commit()

    def get_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID mapped to DM ``legacy_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.xmpp_id,
            user_pk,
            XmppToLegacyIds.legacy_id,
            legacy_id,
            XmppToLegacyEnum.DM,
        )

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID mapped to DM ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            user_pk,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            XmppToLegacyEnum.DM,
        )

    def set_group_message(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Insert a GROUP_CHAT mapping row (no de-duplication)."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.GROUP_CHAT)

    def get_group_xmpp_id(self, user_pk: int, legacy_id: str) -> Optional[str]:
        """Return the XMPP ID mapped to group message ``legacy_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.xmpp_id,
            user_pk,
            XmppToLegacyIds.legacy_id,
            legacy_id,
            XmppToLegacyEnum.GROUP_CHAT,
        )

    def get_group_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID mapped to group message ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            user_pk,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            XmppToLegacyEnum.GROUP_CHAT,
        )

    def set_thread(self, user_pk: int, legacy_id: str, xmpp_id: str) -> None:
        """Insert a THREAD mapping row (no de-duplication)."""
        self._insert(user_pk, legacy_id, xmpp_id, XmppToLegacyEnum.THREAD)

    def get_legacy_thread(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy thread ID mapped to ``xmpp_id``, or None."""
        return self._lookup(
            XmppToLegacyIds.legacy_id,
            user_pk,
            XmppToLegacyIds.xmpp_id,
            xmpp_id,
            XmppToLegacyEnum.THREAD,
        )

    def was_sent_by_user(self, user_pk: int, legacy_id: str) -> bool:
        """Return True if any mapping of any kind exists for ``legacy_id``."""
        stmt = select(XmppToLegacyIds.legacy_id).where(
            XmppToLegacyIds.user_account_id == user_pk,
            XmppToLegacyIds.legacy_id == legacy_id,
        )
        with self.session() as session:
            return session.scalar(stmt) is not None
class ContactStore(UpdatedMixin):
    """Persistence for the legacy contacts (:class:`Contact`) of each user.

    On construction, in addition to the ``updated`` reset done by
    :class:`UpdatedMixin`, ``cached_presence`` is cleared on every contact.
    As a result, :meth:`get_presence` returns None until a presence is stored
    again.
    """

    model = Contact

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        with self.session() as session:
            session.execute(update(Contact).values(cached_presence=False))
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Contact]:
        """Yield every contact of user ``user_pk``; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(
                select(Contact).where(Contact.user_account_id == user_pk)
            ).scalars()

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Contact]:
        """Return the contact of ``user_pk`` whose JID is the bare part of ``jid``."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.jid == jid.bare)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Contact]:
        """Return the contact of ``user_pk`` with the given legacy ID, or None."""
        with self.session() as session:
            return session.execute(
                select(Contact)
                .where(Contact.legacy_id == legacy_id)
                .where(Contact.user_account_id == user_pk)
            ).scalar()

    def update_nick(self, contact_pk: int, nick: Optional[str]) -> None:
        """Set the stored nickname of contact ``contact_pk``."""
        with self.session() as session:
            session.execute(
                update(Contact).where(Contact.id == contact_pk).values(nick=nick)
            )
            session.commit()

    def get_presence(self, contact_pk: int) -> Optional[CachedPresence]:
        """Return the cached presence of a contact.

        Returns None if the contact does not exist or its ``cached_presence``
        flag is not set.
        """
        with self.session() as session:
            presence = session.execute(
                select(
                    Contact.last_seen,
                    Contact.ptype,
                    Contact.pstatus,
                    Contact.pshow,
                    Contact.cached_presence,
                ).where(Contact.id == contact_pk)
            ).first()
            # The last selected column is the "is cached" flag.
            if presence is None or not presence[-1]:
                return None
            return CachedPresence(*presence[:-1])

    def set_presence(self, contact_pk: int, presence: CachedPresence) -> None:
        """Store ``presence`` for the contact and mark it as cached."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(**presence._asdict(), cached_presence=True)
            )
            session.commit()

    def reset_presence(self, contact_pk: int):
        """Clear every stored presence field of the contact and its cached flag."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(
                    last_seen=None,
                    ptype=None,
                    pstatus=None,
                    pshow=None,
                    cached_presence=False,
                )
            )
            session.commit()

    def set_avatar(
        self, contact_pk: int, avatar_pk: Optional[int], avatar_legacy_id: Optional[str]
    ):
        """Point the contact at avatar ``avatar_pk`` and record its legacy ID."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
            )
            session.commit()

    def get_avatar_legacy_id(self, contact_pk: int) -> Optional[str]:
        """Return the contact's avatar legacy ID, or None if it has no avatar."""
        with self.session() as session:
            contact = session.execute(
                select(Contact).where(Contact.id == contact_pk)
            ).scalar()
            if contact is None or contact.avatar is None:
                return None
            return contact.avatar_legacy_id

    def update(self, contact: "LegacyContact", commit=True) -> int:
        """Insert or overwrite the row for ``contact`` and return its primary key.

        If ``contact.contact_pk`` is None, a new row is created, seeded with
        the contact's cached presence if it has one. Otherwise the existing row
        is loaded and overwritten.

        With ``commit=False`` the change is only flushed, so the caller's
        transaction decides whether it is kept.
        """
        with self.session() as session:
            if contact.contact_pk is None:
                if contact.cached_presence is not None:
                    presence_kwargs = contact.cached_presence._asdict()
                    presence_kwargs["cached_presence"] = True
                else:
                    presence_kwargs = {}
                row = Contact(
                    jid=contact.jid.bare,
                    legacy_id=str(contact.legacy_id),
                    user_account_id=contact.user_pk,
                    **presence_kwargs,
                )
            else:
                row = (
                    session.query(Contact)
                    .filter(Contact.id == contact.contact_pk)
                    .one()
                )
            row.nick = contact.name
            row.is_friend = contact.is_friend
            row.added_to_roster = contact.added_to_roster
            row.updated = True
            row.extra_attributes = contact.serialize_extra_attributes()
            row.caps_ver = contact._caps_ver
            row.vcard = contact._vcard
            row.vcard_fetched = contact._vcard_fetched
            row.client_type = contact.client_type
            session.add(row)
            if commit:
                session.commit()
            else:
                # Without a flush, a freshly added row has no database-assigned
                # id yet and this method would return None.
                session.flush()
            return row.id

    def set_vcard(self, contact_pk: int, vcard: str | None) -> None:
        """Store ``vcard`` for the contact and mark its vCard as fetched."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(vcard=vcard, vcard_fetched=True)
            )
            session.commit()

    def add_to_sent(self, contact_pk: int, msg_id: str) -> None:
        """Record that the contact sent ``msg_id``.

        If that pair is already stored, a warning is logged and nothing is
        added.
        """
        with self.session() as session:
            if (
                session.query(ContactSent.id)
                .where(ContactSent.contact_id == contact_pk)
                .where(ContactSent.msg_id == msg_id)
                .first()
            ) is not None:
                log.warning(
                    "Contact %s has already sent message %s", contact_pk, msg_id
                )
                return
            new = ContactSent(contact_id=contact_pk, msg_id=msg_id)
            session.add(new)
            session.commit()

    def pop_sent_up_to(self, contact_pk: int, msg_id: str) -> list[str]:
        """Remove and return the contact's sent IDs, up to and including ``msg_id``.

        Rows are walked in ascending primary-key order. If ``msg_id`` is never
        reached, every row of the contact is popped.
        """
        result: list[str] = []
        to_del: list[int] = []
        with self.session() as session:
            for row in session.execute(
                select(ContactSent)
                .where(ContactSent.contact_id == contact_pk)
                .order_by(ContactSent.id)
            ).scalars():
                to_del.append(row.id)
                result.append(row.msg_id)
                if row.msg_id == msg_id:
                    break
            if to_del:
                session.execute(delete(ContactSent).where(ContactSent.id.in_(to_del)))
                # Commit like every other mutating method. Without it, the
                # deletion is discarded when a freshly opened session closes,
                # and the same IDs would be "popped" again next time.
                session.commit()
        return result

    def set_friend(self, contact_pk: int, is_friend: bool) -> None:
        """Set the contact's ``is_friend`` flag."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(is_friend=is_friend)
            )
            session.commit()

    def set_added_to_roster(self, contact_pk: int, value: bool) -> None:
        """Set the contact's ``added_to_roster`` flag."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(added_to_roster=value)
            )
            session.commit()

    def delete(self, contact_pk: int) -> None:
        """Delete the contact row with primary key ``contact_pk``."""
        with self.session() as session:
            session.execute(delete(Contact).where(Contact.id == contact_pk))
            session.commit()

    def set_client_type(self, contact_pk: int, value: ClientType):
        """Store the contact's client type."""
        with self.session() as session:
            session.execute(
                update(Contact)
                .where(Contact.id == contact_pk)
                .values(client_type=value)
            )
            session.commit()
class MAMStore(EngineMixin):
    """Per-room message archive (MAM) storage.

    On construction, every archived message is re-labelled as BACKFILL.
    Only messages added afterwards with ``archive_only=False`` are LIVE.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        with self.session() as session:
            session.execute(
                update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
            )
            session.commit()

    def nuke_older_than(self, days: int) -> None:
        """Delete archived messages whose timestamp is more than ``days`` days old."""
        # Stored timestamps are treated as UTC (see get_messages), so the cutoff
        # is computed in UTC too; a naive datetime.now() would be local time.
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        with self.session() as session:
            session.execute(
                delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
            )
            session.commit()

    def add_message(
        self,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: str | None,
    ) -> None:
        """Archive ``message`` in room ``room_pk``, or update its existing entry.

        An existing entry is looked up by stanza ID first, then by
        ``legacy_msg_id`` if one is given. The source is BACKFILL when
        ``archive_only`` is true, LIVE otherwise.
        """
        with self.session() as session:
            source = (
                ArchivedMessageSource.BACKFILL
                if archive_only
                else ArchivedMessageSource.LIVE
            )
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.stanza_id == message.id)
            ).scalar()
            if existing is None and legacy_msg_id is not None:
                existing = session.execute(
                    select(ArchivedMessage)
                    .where(ArchivedMessage.room_id == room_pk)
                    .where(ArchivedMessage.legacy_id == legacy_msg_id)
                ).scalar()
            if existing is not None:
                log.debug("Updating message %s in room %s", message.id, room_pk)
                existing.timestamp = message.when
                existing.stanza = str(message.stanza)
                existing.author_jid = message.stanza.get_from()
                existing.source = source
                existing.legacy_id = legacy_msg_id
                session.add(existing)
                session.commit()
                return
            mam_msg = ArchivedMessage(
                stanza_id=message.id,
                timestamp=message.when,
                stanza=str(message.stanza),
                author_jid=message.stanza.get_from(),
                room_id=room_pk,
                source=source,
                legacy_id=legacy_msg_id,
            )
            session.add(mam_msg)
            session.commit()

    @staticmethod
    def _timestamp_of(session: Session, room_pk: int, stanza_id: str) -> datetime:
        """Return the timestamp of message ``stanza_id`` archived in ``room_pk``.

        :raises XMPPError: ``item-not-found`` if this room has no such message.
        """
        stamp = session.execute(
            select(ArchivedMessage.timestamp)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == stanza_id)
        ).scalar()
        if stamp is None:
            raise XMPPError(
                "item-not-found",
                f"Message {stanza_id} not found",
            )
        return stamp

    def get_messages(
        self,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip=False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of ``room_pk`` that match every given filter.

        :param start_date: inclusive lower bound on the timestamp.
        :param end_date: inclusive upper bound on the timestamp.
        :param before_id: only messages strictly older than this stanza ID
            (which must be archived in the same room).
        :param after_id: only messages strictly newer than this stanza ID
            (which must be archived in the same room).
        :param ids: restrict to these stanza IDs; all of them must match.
        :param last_page_n: keep only the ``n`` most recent matches
            (``0`` keeps none).
        :param sender: restrict to this author JID.
        :param flip: yield newest first instead of oldest first.
        :raises XMPPError: ``item-not-found`` for an unknown ``before_id``/
            ``after_id``, or when not every ``ids`` entry matched. Since this
            is a generator, errors surface on first iteration.
        """
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
            if start_date is not None:
                q = q.where(ArchivedMessage.timestamp >= start_date)
            if end_date is not None:
                q = q.where(ArchivedMessage.timestamp <= end_date)
            if before_id is not None:
                stamp = self._timestamp_of(session, room_pk, before_id)
                q = q.where(ArchivedMessage.timestamp < stamp)
            if after_id is not None:
                stamp = self._timestamp_of(session, room_pk, after_id)
                q = q.where(ArchivedMessage.timestamp > stamp)
            if ids:
                q = q.filter(ArchivedMessage.stanza_id.in_(ids))
            if sender is not None:
                q = q.where(ArchivedMessage.author_jid == sender)
            if flip:
                q = q.order_by(ArchivedMessage.timestamp.desc())
            else:
                q = q.order_by(ArchivedMessage.timestamp.asc())
            msgs = list(session.execute(q).scalars())
            if ids and len(msgs) != len(ids):
                raise XMPPError(
                    "item-not-found",
                    "One of the requested messages IDs could not be found "
                    "with the given constraints.",
                )
            if last_page_n is not None:
                if flip:
                    msgs = msgs[:last_page_n]
                else:
                    # msgs[-0:] would be the whole list; compute the start index
                    # so that last_page_n == 0 really yields nothing.
                    msgs = msgs[max(len(msgs) - last_page_n, 0) :]
            for h in msgs:
                yield HistoryMessage(
                    stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
                )

    def get_first(self, room_pk: int, with_legacy_id=False) -> ArchivedMessage | None:
        """Return the oldest archived message of the room, optionally only among those with a legacy ID."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .order_by(ArchivedMessage.timestamp.asc())
            )
            if with_legacy_id:
                q = q.filter(ArchivedMessage.legacy_id.isnot(None))
            return session.execute(q).scalar()

    def get_last(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest archived message of the room, optionally of a given source."""
        with self.session() as session:
            q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

            if source is not None:
                q = q.where(ArchivedMessage.source == source)

            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_first_and_last(self, room_pk: int) -> list[MamMetadata]:
        """Return metadata of the oldest and newest archived messages of the room.

        The list is empty for an empty archive. With a single message, that
        message appears twice.
        """
        r = []
        # One outer session so both lookups share it.
        with self.session():
            first = self.get_first(room_pk)
            if first is not None:
                r.append(MamMetadata(first.stanza_id, first.timestamp))
            last = self.get_last(room_pk)
            if last is not None:
                r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    def get_most_recent_with_legacy_id(
        self, room_pk: int, source: ArchivedMessageSource | None = None
    ) -> ArchivedMessage | None:
        """Return the newest message of the room that has a legacy ID."""
        with self.session() as session:
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
            )
            if source is not None:
                q = q.where(ArchivedMessage.source == source)
            return session.execute(
                q.order_by(ArchivedMessage.timestamp.desc())
            ).scalar()

    def get_least_recent_with_legacy_id_after(
        self, room_pk: int, after_id: str, source=ArchivedMessageSource.LIVE
    ) -> ArchivedMessage | None:
        """Return the oldest message with a legacy ID strictly newer than legacy message ``after_id``.

        Returns None if ``after_id`` is not archived in this room.
        """
        with self.session() as session:
            after_timestamp = (
                session.query(ArchivedMessage.timestamp)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == after_id)
                .scalar()
            )
            if after_timestamp is None:
                # Explicit instead of comparing against NULL, which matches nothing.
                return None
            q = (
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id.isnot(None))
                .where(ArchivedMessage.source == source)
                .where(ArchivedMessage.timestamp > after_timestamp)
            )
            return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    def get_by_legacy_id(self, room_pk: int, legacy_id: str) -> ArchivedMessage | None:
        """Return an archived message of the room with this legacy ID, or None."""
        with self.session() as session:
            return (
                session.query(ArchivedMessage)
                .filter(ArchivedMessage.room_id == room_pk)
                .filter(ArchivedMessage.legacy_id == legacy_id)
                .first()
            )
class MultiStore(EngineMixin):
    """Per-user mapping of one legacy message ID to several XMPP message IDs."""

    def get_xmpp_ids(self, user_pk: int, xmpp_id: str) -> list[str]:
        """Return every XMPP ID grouped with ``xmpp_id`` (including itself).

        Returns an empty list if ``xmpp_id`` is unknown.
        """
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return []
            return [m.xmpp_id for m in multi.legacy_ids_multi.xmpp_ids]

    def set_xmpp_ids(
        self, user_pk: int, legacy_msg_id: str, xmpp_ids: list[str], fail=False
    ) -> None:
        """Associate ``xmpp_ids`` with ``legacy_msg_id`` for user ``user_pk``.

        An existing mapping for the same legacy ID is replaced. It is deleted
        along with any stored entry for one of the new XMPP IDs, and the method
        then calls itself once (``fail=True``) to insert the new mapping.
        Empty strings in ``xmpp_ids`` are not stored.

        :raises RuntimeError: if the mapping still exists on that retry.
        """
        with self.session() as session:
            existing = session.execute(
                select(LegacyIdsMulti)
                .where(LegacyIdsMulti.user_account_id == user_pk)
                .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
            ).scalar()
            if existing is not None:
                if fail:
                    # A bare `raise` here would re-raise whatever exception the
                    # caller happens to be handling, or fail with the opaque
                    # "No active exception to reraise" RuntimeError.
                    raise RuntimeError(
                        f"Could not reset multi XMPP IDs for legacy ID {legacy_msg_id!r}"
                    )
                log.debug("Resetting multi for %s", legacy_msg_id)
                session.execute(
                    delete(LegacyIdsMulti)
                    .where(LegacyIdsMulti.user_account_id == user_pk)
                    .where(LegacyIdsMulti.legacy_id == legacy_msg_id)
                )
                for i in xmpp_ids:
                    session.execute(
                        delete(XmppIdsMulti)
                        .where(XmppIdsMulti.user_account_id == user_pk)
                        .where(XmppIdsMulti.xmpp_id == i)
                    )
                session.commit()
                self.set_xmpp_ids(user_pk, legacy_msg_id, xmpp_ids, True)
                return

            row = LegacyIdsMulti(
                user_account_id=user_pk,
                legacy_id=legacy_msg_id,
                xmpp_ids=[
                    XmppIdsMulti(user_account_id=user_pk, xmpp_id=i)
                    for i in xmpp_ids
                    if i
                ],
            )
            session.add(row)
            session.commit()

    def get_legacy_id(self, user_pk: int, xmpp_id: str) -> Optional[str]:
        """Return the legacy ID that ``xmpp_id`` belongs to, or None."""
        with self.session() as session:
            multi = session.execute(
                select(XmppIdsMulti)
                .where(XmppIdsMulti.xmpp_id == xmpp_id)
                .where(XmppIdsMulti.user_account_id == user_pk)
            ).scalar()
            if multi is None:
                return None
            return multi.legacy_ids_multi.legacy_id
class AttachmentStore(EngineMixin):
    """Cache of uploaded attachments: legacy file ID -> URL, plus SIMS/SFS data per URL."""

    def get_url(self, legacy_file_id: str) -> Optional[str]:
        """Return the URL stored for ``legacy_file_id`` (any user), or None."""
        stmt = select(Attachment.url).where(Attachment.legacy_file_id == legacy_file_id)
        with self.session() as session:
            return session.scalar(stmt)

    def set_url(self, user_pk: int, legacy_file_id: str, url: str) -> None:
        """Insert or update the URL of ``legacy_file_id`` for user ``user_pk``."""
        with self.session() as session:
            found = session.scalar(
                select(Attachment).where(
                    Attachment.legacy_file_id == legacy_file_id,
                    Attachment.user_account_id == user_pk,
                )
            )
            if found is not None:
                found.url = url
            else:
                session.add(
                    Attachment(
                        legacy_file_id=legacy_file_id,
                        url=url,
                        user_account_id=user_pk,
                    )
                )
            session.commit()

    def get_sims(self, url: str) -> Optional[str]:
        """Return the stored SIMS payload for ``url``, or None."""
        with self.session() as session:
            return session.scalar(select(Attachment.sims).where(Attachment.url == url))

    def set_sims(self, url: str, sims: str) -> None:
        """Store the SIMS payload on every attachment row with this ``url``."""
        with self.session() as session:
            session.execute(
                update(Attachment).values(sims=sims).where(Attachment.url == url)
            )
            session.commit()

    def get_sfs(self, url: str) -> Optional[str]:
        """Return the stored SFS payload for ``url``, or None."""
        with self.session() as session:
            return session.scalar(select(Attachment.sfs).where(Attachment.url == url))

    def set_sfs(self, url: str, sfs: str) -> None:
        """Store the SFS payload on every attachment row with this ``url``."""
        with self.session() as session:
            session.execute(
                update(Attachment).values(sfs=sfs).where(Attachment.url == url)
            )
            session.commit()

    def remove(self, legacy_file_id: str) -> None:
        """Delete every attachment row for ``legacy_file_id``."""
        with self.session() as session:
            session.execute(
                delete(Attachment).where(Attachment.legacy_file_id == legacy_file_id)
            )
            session.commit()
class RoomStore(UpdatedMixin):
    """Persistence for the group chats (:class:`Room`) of each user.

    On construction, in addition to the ``updated`` reset done by
    :class:`UpdatedMixin`, these columns are reset on every room:
    ``subject_setter``, ``user_resources``, ``history_filled`` and
    ``participants_filled``.
    """

    model = Room

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        with self.session() as session:
            session.execute(
                update(Room).values(
                    subject_setter=None,
                    user_resources=None,
                    history_filled=False,
                    participants_filled=False,
                )
            )
            session.commit()

    def set_avatar(
        self, room_pk: int, avatar_pk: int | None, avatar_legacy_id: str | None
    ) -> None:
        """Point the room at avatar ``avatar_pk`` and record its legacy ID."""
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(avatar_id=avatar_pk, avatar_legacy_id=avatar_legacy_id)
            )
            session.commit()

    def get_avatar_legacy_id(self, room_pk: int) -> Optional[str]:
        """Return the room's avatar legacy ID, or None if it has no avatar."""
        with self.session() as session:
            room = session.execute(select(Room).where(Room.id == room_pk)).scalar()
            if room is None or room.avatar is None:
                return None
            return room.avatar_legacy_id

    def get_by_jid(self, user_pk: int, jid: JID) -> Optional[Room]:
        """Return the room of ``user_pk`` with this bare JID, or None.

        :raises TypeError: if ``jid`` has a resource.
        """
        if jid.resource:
            raise TypeError(f"Room lookup expects a bare JID, got {jid}")
        with self.session() as session:
            return session.execute(
                select(Room)
                .where(Room.user_account_id == user_pk)
                .where(Room.jid == jid)
            ).scalar()

    def get_by_legacy_id(self, user_pk: int, legacy_id: str) -> Optional[Room]:
        """Return the room of ``user_pk`` with this legacy ID, or None."""
        with self.session() as session:
            return session.execute(
                select(Room)
                .where(Room.user_account_id == user_pk)
                .where(Room.legacy_id == legacy_id)
            ).scalar()

    def update_subject_setter(self, room_pk: int, subject_setter: str | None):
        """Store who set the room subject."""
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(subject_setter=subject_setter)
            )
            session.commit()

    def update(self, muc: "LegacyMUC") -> int:
        """Insert (if ``muc.pk`` is None) or overwrite the row for ``muc``; return its id."""
        with self.session() as session:
            if muc.pk is None:
                row = Room(
                    jid=muc.jid,
                    legacy_id=str(muc.legacy_id),
                    user_account_id=muc.user_pk,
                )
            else:
                row = session.query(Room).filter(Room.id == muc.pk).one()

            row.updated = True
            row.extra_attributes = muc.serialize_extra_attributes()
            row.name = muc.name
            row.description = muc.description
            # Resources are stored as a JSON list; an empty set is stored as NULL.
            row.user_resources = (
                None
                if not muc._user_resources
                else json.dumps(list(muc._user_resources))
            )
            row.muc_type = muc.type
            row.subject = muc.subject
            row.subject_date = muc.subject_date
            row.subject_setter = muc.subject_setter
            row.participants_filled = muc._participants_filled
            row.n_participants = muc._n_participants
            row.user_nick = muc.user_nick
            session.add(row)
            session.commit()
            return row.id

    def update_subject_date(
        self, room_pk: int, subject_date: Optional[datetime]
    ) -> None:
        """Store the date the room subject was set."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject_date=subject_date)
            )
            session.commit()

    def update_subject(self, room_pk: int, subject: Optional[str]) -> None:
        """Store the room subject."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(subject=subject)
            )
            session.commit()

    def update_description(self, room_pk: int, desc: Optional[str]) -> None:
        """Store the room description."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(description=desc)
            )
            session.commit()

    def update_name(self, room_pk: int, name: Optional[str]) -> None:
        """Store the room name."""
        with self.session() as session:
            session.execute(update(Room).where(Room.id == room_pk).values(name=name))
            session.commit()

    def update_n_participants(self, room_pk: int, n: Optional[int]) -> None:
        """Store the room's participant count."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(n_participants=n)
            )
            session.commit()

    def update_user_nick(self, room_pk, nick: str) -> None:
        """Store the user's own nickname in the room."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(user_nick=nick)
            )
            session.commit()

    def delete(self, room_pk: int) -> None:
        """Delete the room and its participants."""
        with self.session() as session:
            # Children first, so a foreign key from Participant to Room can
            # never be left dangling mid-transaction.
            session.execute(delete(Participant).where(Participant.room_id == room_pk))
            session.execute(delete(Room).where(Room.id == room_pk))
            session.commit()

    def set_resource(self, room_pk: int, resources: set[str]) -> None:
        """Store the user's joined resources as a JSON list (NULL when empty)."""
        with self.session() as session:
            session.execute(
                update(Room)
                .where(Room.id == room_pk)
                .values(
                    user_resources=(
                        None if not resources else json.dumps(list(resources))
                    )
                )
            )
            session.commit()

    def nickname_is_available(self, room_pk: int, nickname: str) -> bool:
        """Return True if no participant of the room uses ``nickname``."""
        with self.session() as session:
            return (
                session.execute(
                    select(Participant)
                    .where(Participant.room_id == room_pk)
                    .where(Participant.nickname == nickname)
                ).scalar()
                is None
            )

    def set_participants_filled(self, room_pk: int, val=True) -> None:
        """Set the room's ``participants_filled`` flag to ``val``."""
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(participants_filled=val)
            )
            session.commit()

    def set_history_filled(self, room_pk: int, val=True) -> None:
        """Set the room's ``history_filled`` flag to ``val``.

        Previously ``val`` was ignored and the flag was always set to True.
        """
        with self.session() as session:
            session.execute(
                update(Room).where(Room.id == room_pk).values(history_filled=val)
            )
            session.commit()

    def get_all(self, user_pk: int) -> Iterator[Room]:
        """Yield every room of ``user_pk``; the session stays open while iterating."""
        with self.session() as session:
            yield from session.execute(
                select(Room).where(Room.user_account_id == user_pk)
            ).scalars()

    def get_all_jid_and_names(self, user_pk: int) -> Iterator[Room]:
        """Yield the rooms of ``user_pk`` sorted by name, loading only ``jid`` and ``name``."""
        with self.session() as session:
            yield from session.scalars(
                select(Room)
                .filter(Room.user_account_id == user_pk)
                .options(load_only(Room.jid, Room.name))
                .order_by(Room.name)
            ).all()
class ParticipantStore(EngineMixin):
    """Persistence layer for MUC participants and their hats.

    Participants are addressed by their integer primary key and belong to a
    room (``room_id``). All work goes through :meth:`EngineMixin.session`,
    so calls made while another session is active reuse that session.
    """

    def __init__(self, *a, **kw):
        """Initialise the store and wipe every participant-related row.

        Participant/hat associations, hats and participants are all deleted
        (and committed) on construction.
        """
        super().__init__(*a, **kw)
        with self.session() as session:
            # Association rows go first, presumably so that no link row
            # outlives the hat/participant rows it references.
            session.execute(delete(participant_hats))
            session.execute(delete(Hat))
            session.execute(delete(Participant))
            session.commit()

    def add(self, room_pk: int, nickname: str) -> int:
        """Return the pk of participant *nickname* in room *room_pk*.

        The participant row is created and committed if it does not exist.
        """
        with self.session() as session:
            existing = session.execute(
                select(Participant.id)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()
            if existing is not None:
                return existing
            participant = Participant(room_id=room_pk, nickname=nickname)
            session.add(participant)
            session.commit()
            return participant.id

    def get_by_nickname(self, room_pk: int, nickname: str) -> Optional[Participant]:
        """Return the participant of *room_pk* with this nickname, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.nickname == nickname)
            ).scalar()

    def get_by_resource(self, room_pk: int, resource: str) -> Optional[Participant]:
        """Return the participant of *room_pk* with this JID resource, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.resource == resource)
            ).scalar()

    def get_by_contact(self, room_pk: int, contact_pk: int) -> Optional[Participant]:
        """Return the participant of *room_pk* linked to contact *contact_pk*, or None."""
        with self.session() as session:
            return session.execute(
                select(Participant)
                .where(Participant.room_id == room_pk)
                .where(Participant.contact_id == contact_pk)
            ).scalar()

    def get_all(self, room_pk: int, user_included=True) -> Iterator[Participant]:
        """Yield every participant of room *room_pk*.

        :param user_included: when False, rows flagged ``is_user`` are skipped.

        This is a generator: the session stays open until the iterator is
        exhausted or closed.
        """
        with self.session() as session:
            q = select(Participant).where(Participant.room_id == room_pk)
            if not user_included:
                q = q.where(~Participant.is_user)
            yield from session.execute(q).scalars()

    def get_for_contact(self, contact_pk: int) -> Iterator[Participant]:
        """Yield every participant (in any room) linked to contact *contact_pk*.

        This is a generator: the session stays open until the iterator is
        exhausted or closed.
        """
        with self.session() as session:
            yield from session.execute(
                select(Participant).where(Participant.contact_id == contact_pk)
            ).scalars()

    def update(self, participant: "LegacyParticipant") -> None:
        """Persist the state of a live participant object into its row.

        Writes resource, affiliation, role, the presence-sent flag, the
        ``is_user`` flag and the linked contact pk (or NULL). Hats are not
        written here; :meth:`set_hats` manages them.
        """
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant.pk)
                .values(
                    resource=participant.jid.resource,
                    affiliation=participant.affiliation,
                    role=participant.role,
                    presence_sent=participant._presence_sent,  # type:ignore
                    is_user=participant.is_user,
                    contact_id=(
                        None
                        if participant.contact is None
                        else participant.contact.contact_pk
                    ),
                )
            )
            session.commit()

    def add_hat(self, uri: str, title: str) -> Hat:
        """Return the hat with this *uri* and *title*, creating it if needed."""
        with self.session() as session:
            existing = session.execute(
                select(Hat).where(Hat.uri == uri).where(Hat.title == title)
            ).scalar()
            if existing is not None:
                return existing
            hat = Hat(uri=uri, title=title)
            session.add(hat)
            session.commit()
            return hat

    def set_presence_sent(self, participant_pk: int) -> None:
        """Flag the participant's presence as sent."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(presence_sent=True)
            )
            session.commit()

    def set_affiliation(self, participant_pk: int, affiliation: MucAffiliation) -> None:
        """Store a new MUC affiliation for the participant."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(affiliation=affiliation)
            )
            session.commit()

    def set_role(self, participant_pk: int, role: MucRole) -> None:
        """Store a new MUC role for the participant."""
        with self.session() as session:
            session.execute(
                update(Participant)
                .where(Participant.id == participant_pk)
                .values(role=role)
            )
            session.commit()

    def set_hats(self, participant_pk: int, hats: list[HatTuple]) -> None:
        """Replace the participant's hats with *hats*.

        Each element of *hats* is unpacked into :meth:`add_hat` as
        ``(uri, title)``, so missing hats are created on the fly.

        :raises ValueError: if no participant has this pk.
        """
        with self.session() as session:
            part = session.execute(
                select(Participant).where(Participant.id == participant_pk)
            ).scalar()
            if part is None:
                raise ValueError(f"No participant with pk {participant_pk}")
            part.hats.clear()
            for h in hats:
                hat = self.add_hat(*h)
                # The list was just cleared: this only skips duplicates in *hats*.
                if hat in part.hats:
                    continue
                part.hats.append(hat)
            session.commit()

    def delete(self, participant_pk: int) -> None:
        """Delete the participant row with this pk."""
        with self.session() as session:
            session.execute(delete(Participant).where(Participant.id == participant_pk))
            # Without an explicit commit, closing the session rolls the
            # deletion back, unlike every other mutating method here.
            session.commit()

    def get_count(self, room_pk: int) -> int:
        """Return the number of participants stored for room *room_pk*."""
        with self.session() as session:
            return session.scalar(
                select(count(Participant.id)).where(Participant.room_id == room_pk)
            )
class BobStore(EngineMixin):
    """Storage for XEP-0231 Bits of Binary payloads (e.g. stickers).

    Hashes, content type and file name are kept in the ``Bob`` table. The
    payload bytes live in a file under ``config.HOME_DIR / "slidge_stickers"``.
    CIDs look like ``<algo>+<hex digest>@bob.xmpp.org``.
    """

    # Hash algorithm name as found in a CID -> Bob column holding that digest.
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob digest column -> hashlib constructor computing it.
    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        self.root_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split a CID into ``[algo, digest]`` (after dropping the host suffix).

        Callers unpack the result into exactly two names, so a malformed CID
        surfaces there as ValueError.
        """
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        """Build the SQL condition selecting the Bob row for *cid*.

        :return: the condition, or None (after logging a warning) when the
            hash algorithm is unknown.
        :raises ValueError: if *cid* does not split into exactly two parts.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algo: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, cid: str) -> Bob | None:
        """Return the stored Bob row for *cid*, or None if absent or invalid."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None
        if condition is None:
            # Unknown algorithm (already logged): nothing can match.
            return None
        with self.session() as session:
            return session.query(Bob).filter(condition).scalar()

    def get_sticker(self, cid: str) -> Sticker | None:
        """Return a Sticker (file path, content type, all digests) for *cid*, or None."""
        bob = self.get(cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(self, _jid, _node, _ifrom, cid: str) -> BitsOfBinary | None:
        """Build a BitsOfBinary stanza for *cid* from the stored file, or None.

        The first three arguments are unused; the signature presumably
        follows slixmpp's XEP-0231 handler API (TODO confirm).
        """
        stored = self.get(cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, _jid, _node, _ifrom, cid: str) -> None:
        """Delete the Bob row for *cid* and its payload file.

        Invalid or unknown CIDs are logged and ignored. A payload file that
        is already missing does not prevent the row from being deleted.
        """
        with self.session() as orm:
            try:
                condition = self.__get_condition(cid)
            except ValueError:
                log.warning("Cannot delete Bob with CID: %s", cid)
                return None
            if condition is None:
                # Unknown algorithm (already logged): nothing can match.
                return None
            file_name = orm.scalar(
                delete(Bob).where(condition).returning(Bob.file_name)
            )
            if file_name is None:
                log.warning("No BoB with CID: %s", cid)
                return None
            # missing_ok: a file removed out-of-band must not block the DB
            # deletion (an exception here would skip the commit below).
            (self.root_dir / file_name).unlink(missing_ok=True)
            orm.commit()

    def set_bob(self, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        """Store a BitsOfBinary payload on disk and record it in the database.

        Payloads whose CID is already known are ignored.

        :raises ValueError: if the CID digest does not match the payload.
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Cannot set Bob with CID: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set BoB with unknown hash algo: %s", alg_name)
            return None
        with self.session() as orm:
            existing = self.get(cid)
            if existing is not None:
                log.debug("Bob already known")
                return
            bytes_ = bob["data"]
            hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
            # Validate before touching the disk, so a bogus CID cannot leave
            # an orphan file behind in root_dir.
            if hashes[attr] != digest:
                raise ValueError(
                    "The given CID does not correspond to the result of our hash"
                )
            path = self.root_dir / uuid.uuid4().hex
            if bob["type"]:
                path = path.with_suffix(guess_extension(bob["type"]) or "")
            path.write_bytes(bytes_)
            row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
            orm.add(row)
            orm.commit()
# Module-level logger used by every store class in this file.
log = logging.getLogger(name=__name__)

# Session currently opened by EngineMixin.session(); while it is set, nested
# session() calls reuse it instead of opening a new one. None when idle.
_session: Session | None = None