Coverage for slidge/db/store.py: 87%

291 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1from __future__ import annotations 

2 

3import hashlib 

4import logging 

5import uuid 

6from datetime import datetime, timedelta, timezone 

7from mimetypes import guess_extension 

8from typing import Collection, Iterator, Optional, Type 

9 

10from slixmpp.exceptions import XMPPError 

11from slixmpp.plugins.xep_0231.stanza import BitsOfBinary 

12from sqlalchemy import Engine, delete, select, update 

13from sqlalchemy.exc import InvalidRequestError 

14from sqlalchemy.orm import Session, attributes, sessionmaker 

15 

16from ..core import config 

17from ..util.archive_msg import HistoryMessage 

18from ..util.types import MamMetadata, Sticker 

19from .meta import Base 

20from .models import ( 

21 ArchivedMessage, 

22 ArchivedMessageSource, 

23 Bob, 

24 Contact, 

25 ContactSent, 

26 DirectMessages, 

27 DirectThreads, 

28 GatewayUser, 

29 GroupMessages, 

30 GroupThreads, 

31 Participant, 

32 Room, 

33) 

34 

35 

class UpdatedMixin:
    """Mixin for stores whose model has an ``updated`` boolean column.

    Creating the store sets ``updated=False`` on every row of :attr:`model`.
    """

    # Subclasses override this with the mapped model class they manage.
    model: Type[Base] = NotImplemented

    def __init__(self, session: Session) -> None:
        # Executed in the caller's session; this method does not commit.
        session.execute(update(self.model).values(updated=False))

    def get_by_pk(self, session: Session, pk: int) -> Optional[Base]:
        """Return the row of :attr:`model` whose ``id`` equals *pk*.

        :return: the model instance, or ``None`` if no row matches.
        """
        stmt = select(self.model).where(self.model.id == pk)  # type:ignore
        return session.scalar(stmt)

45 

46 

class SlidgeStore:
    """Aggregate of every database-backed store used by the gateway.

    The stores whose constructors execute startup statements (contacts,
    archive, rooms, participants) share one session, which is committed
    once all of them have been created.
    """

    def __init__(self, engine: Engine) -> None:
        self._engine = engine
        make_session = sessionmaker(engine)
        self.session = make_session

        # Stores that only keep the session factory (or nothing) for later use.
        self.users = UserStore(make_session)
        self.avatars = AvatarStore(make_session)
        self.id_map = IdMapStore()
        self.bob = BobStore()

        # Stores whose constructors run UPDATE/DELETE statements right away.
        with make_session() as startup:
            self.contacts = ContactStore(startup)
            self.mam = MAMStore(startup, make_session)
            self.rooms = RoomStore(startup)
            self.participants = ParticipantStore(startup)
            startup.commit()

62 

63 

class UserStore:
    """Persistence helper for :class:`GatewayUser` rows."""

    def __init__(self, session_maker) -> None:
        # Only the factory is kept; each operation opens its own session.
        self.session = session_maker

    def update(self, user: GatewayUser) -> None:
        """Add *user* to a fresh session and commit it.

        ``legacy_module_data`` and ``preferences`` are explicitly flagged as
        modified so that in-place changes to them are written. SQLAlchemy's
        ``flag_modified`` raises ``InvalidRequestError`` when an attribute has
        no value loaded; that case is tolerated. The session is opened with
        ``expire_on_commit=False`` so *user* stays readable after the commit.
        """
        with self.session(expire_on_commit=False) as session:
            # https://github.com/sqlalchemy/sqlalchemy/discussions/6473
            try:
                for attr_name in ("legacy_module_data", "preferences"):
                    attributes.flag_modified(user, attr_name)
            except InvalidRequestError:
                pass
            session.add(user)
            session.commit()

78 

79 

class AvatarStore:
    """Avatar store; at the moment it only keeps a session factory."""

    def __init__(self, session_maker) -> None:
        self.session = session_maker

83 

84 

# Any of the four mapped classes that link a legacy ID to XMPP IDs:
# group or direct chat, message or thread. IdMapStore picks one per call.
LegacyToXmppType = (
    Type[DirectMessages] | Type[DirectThreads] | Type[GroupMessages] | Type[GroupThreads]
)

91 

92 

class IdMapStore:
    """Maps legacy message and thread IDs to XMPP IDs.

    Every mapping is scoped by ``foreign_key`` (presumably the room or contact
    row the message belongs to — confirm against callers). Four tables are
    used: group vs. direct chats, each split into messages and threads; the
    ``group`` flag of the public methods selects the group or direct table.
    """

    @staticmethod
    def _set(
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        type_: LegacyToXmppType,
    ) -> None:
        """Replace every mapping of *legacy_id* (under *foreign_key*) in
        *type_* with one row per entry of *xmpp_ids*."""
        # Materialise the IDs: a ScalarResult object is always truthy, so
        # testing it directly would always log and always issue the DELETE.
        ids = list(
            session.scalars(
                select(type_.id).filter(
                    type_.foreign_key == foreign_key, type_.legacy_id == legacy_id
                )
            )
        )
        if ids:
            log.debug("Resetting legacy ID %s", legacy_id)
            session.execute(delete(type_).where(type_.id.in_(ids)))
        for xmpp_id in xmpp_ids:
            session.add(
                type_(xmpp_id=xmpp_id, foreign_key=foreign_key, legacy_id=legacy_id)
            )

    def set_thread(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_id: str,
        group: bool,
    ) -> None:
        """Map legacy thread *legacy_id* to XMPP thread *xmpp_id*,
        replacing any previous mapping."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            [xmpp_id],
            GroupThreads if group else DirectThreads,
        )

    def set_msg(
        self,
        session: Session,
        foreign_key: int,
        legacy_id: str,
        xmpp_ids: list[str],
        group: bool,
    ) -> None:
        """Map legacy message *legacy_id* to *xmpp_ids*, replacing any
        previous mapping."""
        self._set(
            session,
            foreign_key,
            legacy_id,
            xmpp_ids,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get(
        session: Session, foreign_key: int, legacy_id: str, type_: LegacyToXmppType
    ) -> list[str]:
        """Return every XMPP ID mapped to *legacy_id* in *type_*."""
        return list(
            session.scalars(
                select(type_.xmpp_id).filter_by(
                    foreign_key=foreign_key, legacy_id=legacy_id
                )
            )
        )

    def get_xmpp(
        self, session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> list[str]:
        """Return the XMPP message IDs mapped to legacy message *legacy_id*."""
        return self._get(
            session,
            foreign_key,
            legacy_id,
            GroupMessages if group else DirectMessages,
        )

    @staticmethod
    def _get_legacy(
        session: Session, foreign_key: int, xmpp_id: str, type_: LegacyToXmppType
    ) -> Optional[str]:
        """Return the legacy ID mapped to *xmpp_id* in *type_*, or ``None``.

        If several rows match, the first one returned by the database wins.
        """
        return session.scalar(
            select(type_.legacy_id).filter_by(foreign_key=foreign_key, xmpp_id=xmpp_id)
        )

    def get_legacy(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        """Return the legacy message ID mapped to XMPP message *xmpp_id*."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupMessages if group else DirectMessages,
        )

    def get_thread(
        self, session: Session, foreign_key: int, xmpp_id: str, group: bool
    ) -> Optional[str]:
        """Return the legacy thread ID mapped to XMPP thread *xmpp_id*."""
        return self._get_legacy(
            session,
            foreign_key,
            xmpp_id,
            GroupThreads if group else DirectThreads,
        )

    @staticmethod
    def was_sent_by_user(
        session: Session, foreign_key: int, legacy_id: str, group: bool
    ) -> bool:
        """Return True if legacy message *legacy_id* has at least one XMPP
        message ID mapped under *foreign_key*."""
        type_ = GroupMessages if group else DirectMessages
        return (
            session.scalar(
                select(type_.id).filter_by(foreign_key=foreign_key, legacy_id=legacy_id)
            )
            is not None
        )

208 

209 

class ContactStore(UpdatedMixin):
    """Store for :class:`Contact` rows and the message IDs contacts sent."""

    model = Contact

    def __init__(self, session: Session) -> None:
        # The mixin resets ``updated``; the cached-presence flag is also
        # cleared on every contact.
        super().__init__(session)
        session.execute(update(Contact).values(cached_presence=False))

    @staticmethod
    def add_to_sent(session: Session, contact_pk: int, msg_id: str) -> None:
        """Record that contact *contact_pk* sent message *msg_id*.

        If that pair is already recorded, nothing is inserted and a warning
        is logged.
        """
        duplicate = (
            session.query(ContactSent.id)
            .where(ContactSent.contact_id == contact_pk, ContactSent.msg_id == msg_id)
            .first()
        )
        if duplicate is None:
            session.add(ContactSent(contact_id=contact_pk, msg_id=msg_id))
        else:
            log.warning("Contact %s has already sent message %s", contact_pk, msg_id)

    @staticmethod
    def pop_sent_up_to(session: Session, contact_pk: int, msg_id: str) -> list[str]:
        """Delete and return the contact's recorded message IDs, in ``id``
        order, up to and including *msg_id*.

        If *msg_id* is not recorded for this contact, every recorded ID of
        the contact is deleted and returned.
        """
        popped_row_ids: list[int] = []
        popped_msg_ids: list[str] = []
        rows = session.execute(
            select(ContactSent)
            .where(ContactSent.contact_id == contact_pk)
            .order_by(ContactSent.id)
        ).scalars()
        for sent in rows:
            popped_row_ids.append(sent.id)
            popped_msg_ids.append(sent.msg_id)
            if sent.msg_id == msg_id:
                break
        session.execute(delete(ContactSent).where(ContactSent.id.in_(popped_row_ids)))
        return popped_msg_ids

245 

246 

class MAMStore:
    """Message archive (MAM) storage for rooms."""

    def __init__(self, session: Session, session_maker) -> None:
        self.session = session_maker
        # Every message already archived is relabelled BACKFILL when the
        # store is created (executed in the caller's session, not committed).
        session.execute(
            update(ArchivedMessage).values(source=ArchivedMessageSource.BACKFILL)
        )

    @staticmethod
    def nuke_older_than(session: Session, days: int) -> None:
        """Delete every archived message whose timestamp is more than *days*
        days in the past."""
        # Stored timestamps are interpreted as UTC when read back (see
        # get_messages), so the cutoff must be computed in UTC too; a naive
        # ``datetime.now()`` would be shifted by the host's UTC offset.
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        session.execute(
            delete(ArchivedMessage).where(ArchivedMessage.timestamp < cutoff)
        )

    @staticmethod
    def add_message(
        session: Session,
        room_pk: int,
        message: HistoryMessage,
        archive_only: bool,
        legacy_msg_id: Optional[str],
    ) -> None:
        """Insert *message* into the archive of room *room_pk*, or update it.

        An existing row is looked up by stanza ID first, then (when
        *legacy_msg_id* is given) by legacy ID. If found, its timestamp,
        stanza, author, source and legacy ID are overwritten.

        :param archive_only: store with source BACKFILL if true, LIVE otherwise.
        """
        source = (
            ArchivedMessageSource.BACKFILL
            if archive_only
            else ArchivedMessageSource.LIVE
        )
        existing = session.execute(
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.stanza_id == message.id)
        ).scalar()
        if existing is None and legacy_msg_id is not None:
            existing = session.execute(
                select(ArchivedMessage)
                .where(ArchivedMessage.room_id == room_pk)
                .where(ArchivedMessage.legacy_id == legacy_msg_id)
            ).scalar()
        if existing is not None:
            log.debug("Updating message %s in room %s", message.id, room_pk)
            existing.timestamp = message.when
            existing.stanza = str(message.stanza)
            existing.author_jid = message.stanza.get_from()
            existing.source = source
            # NOTE(review): this clears a stored legacy ID when legacy_msg_id
            # is None and the row was matched by stanza ID — confirm intended.
            existing.legacy_id = legacy_msg_id
            session.add(existing)
            return
        mam_msg = ArchivedMessage(
            stanza_id=message.id,
            timestamp=message.when,
            stanza=str(message.stanza),
            author_jid=message.stanza.get_from(),
            room_id=room_pk,
            source=source,
            legacy_id=legacy_msg_id,
        )
        session.add(mam_msg)

    @staticmethod
    def get_messages(
        session: Session,
        room_pk: int,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before_id: Optional[str] = None,
        after_id: Optional[str] = None,
        ids: Collection[str] = (),
        last_page_n: Optional[int] = None,
        sender: Optional[str] = None,
        flip: bool = False,
    ) -> Iterator[HistoryMessage]:
        """Yield archived messages of room *room_pk* matching the filters.

        :param start_date: inclusive lower bound on the timestamp.
        :param end_date: inclusive upper bound on the timestamp.
        :param before_id: only messages strictly older than this stanza ID.
        :param after_id: only messages strictly newer than this stanza ID.
        :param ids: restrict to these stanza IDs; all of them must be found.
        :param last_page_n: keep only the *n* most recent matching messages.
        :param sender: restrict to this author JID.
        :param flip: newest first instead of oldest first.
        :raises XMPPError: ``item-not-found`` if *before_id*/*after_id* is not
            in the room or some of *ids* do not match. As this is a
            generator, errors surface on the first iteration.

        Yielded timestamps have ``tzinfo`` set to UTC.
        """
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)
        if start_date is not None:
            q = q.where(ArchivedMessage.timestamp >= start_date)
        if end_date is not None:
            q = q.where(ArchivedMessage.timestamp <= end_date)
        if before_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == before_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {before_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp < stamp)
        if after_id is not None:
            stamp = session.execute(
                select(ArchivedMessage.timestamp).where(
                    ArchivedMessage.stanza_id == after_id,
                    ArchivedMessage.room_id == room_pk,
                )
            ).scalar_one_or_none()
            if stamp is None:
                raise XMPPError(
                    "item-not-found",
                    f"Message {after_id} not found",
                )
            q = q.where(ArchivedMessage.timestamp > stamp)
        if ids:
            q = q.filter(ArchivedMessage.stanza_id.in_(ids))
        if sender is not None:
            q = q.where(ArchivedMessage.author_jid == sender)
        if flip:
            q = q.order_by(ArchivedMessage.timestamp.desc())
        else:
            q = q.order_by(ArchivedMessage.timestamp.asc())
        msgs = list(session.execute(q).scalars())
        if ids and len(msgs) != len(ids):
            raise XMPPError(
                "item-not-found",
                "One of the requested messages IDs could not be found "
                "with the given constraints.",
            )
        if last_page_n is not None:
            # In both orders this keeps the n most recent messages.
            if flip:
                msgs = msgs[:last_page_n]
            else:
                msgs = msgs[-last_page_n:]
        for h in msgs:
            yield HistoryMessage(
                stanza=str(h.stanza), when=h.timestamp.replace(tzinfo=timezone.utc)
            )

    @staticmethod
    def get_first(
        session: Session, room_pk: int, with_legacy_id: bool = False
    ) -> Optional[ArchivedMessage]:
        """Return the oldest archived message of the room, or ``None``.

        :param with_legacy_id: only consider messages that have a legacy ID.
        """
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .order_by(ArchivedMessage.timestamp.asc())
        )
        if with_legacy_id:
            q = q.filter(ArchivedMessage.legacy_id.isnot(None))
        return session.execute(q).scalar()

    @staticmethod
    def get_last(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        """Return the newest archived message of the room, optionally
        restricted to *source*, or ``None``."""
        q = select(ArchivedMessage).where(ArchivedMessage.room_id == room_pk)

        if source is not None:
            q = q.where(ArchivedMessage.source == source)

        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    def get_first_and_last(self, session: Session, room_pk: int) -> list[MamMetadata]:
        """Return metadata of the oldest and newest archived messages.

        The list is empty for an empty archive; it contains the same message
        twice when the archive holds a single message.
        """
        r = []
        first = self.get_first(session, room_pk)
        if first is not None:
            r.append(MamMetadata(first.stanza_id, first.timestamp))
        last = self.get_last(session, room_pk)
        if last is not None:
            r.append(MamMetadata(last.stanza_id, last.timestamp))
        return r

    @staticmethod
    def get_most_recent_with_legacy_id(
        session: Session, room_pk: int, source: Optional[ArchivedMessageSource] = None
    ) -> Optional[ArchivedMessage]:
        """Return the newest message that has a legacy ID, optionally
        restricted to *source*, or ``None``."""
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
        )
        if source is not None:
            q = q.where(ArchivedMessage.source == source)
        return session.execute(q.order_by(ArchivedMessage.timestamp.desc())).scalar()

    @staticmethod
    def get_least_recent_with_legacy_id_after(
        session: Session,
        room_pk: int,
        after_id: str,
        source: ArchivedMessageSource = ArchivedMessageSource.LIVE,
    ) -> Optional[ArchivedMessage]:
        """Return the oldest message with a legacy ID and the given *source*
        that is newer than the message whose legacy ID is *after_id*.

        Returns ``None`` if there is no such message, or if *after_id* is not
        in the room's archive.
        """
        after_timestamp = (
            session.query(ArchivedMessage.timestamp)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == after_id)
            .scalar()
        )
        if after_timestamp is None:
            # Unknown reference message. Comparing a column to None with ``>``
            # is rejected by SQLAlchemy, so answer "nothing after it" here.
            return None
        q = (
            select(ArchivedMessage)
            .where(ArchivedMessage.room_id == room_pk)
            .where(ArchivedMessage.legacy_id.isnot(None))
            .where(ArchivedMessage.source == source)
            .where(ArchivedMessage.timestamp > after_timestamp)
        )
        return session.execute(q.order_by(ArchivedMessage.timestamp.asc())).scalar()

    @staticmethod
    def get_by_legacy_id(
        session: Session, room_pk: int, legacy_id: str
    ) -> Optional[ArchivedMessage]:
        """Return an archived message of the room with this legacy ID, or
        ``None``."""
        return (
            session.query(ArchivedMessage)
            .filter(ArchivedMessage.room_id == room_pk)
            .filter(ArchivedMessage.legacy_id == legacy_id)
            .first()
        )

454 

455 

class RoomStore(UpdatedMixin):
    """Store for :class:`Room` rows."""

    model = Room

    def __init__(self, session: Session) -> None:
        super().__init__(session)
        # Reset these columns on every room when the store is created.
        reset_values = dict(
            subject_setter=None,
            user_resources=None,
            history_filled=False,
            participants_filled=False,
        )
        session.execute(update(Room).values(**reset_values))

    @staticmethod
    def get_all(session: Session, user_pk: int) -> Iterator[Room]:
        """Yield every room whose ``user_account_id`` is *user_pk*."""
        rooms = session.scalars(select(Room).where(Room.user_account_id == user_pk))
        for room in rooms:
            yield room

473 

474 

class ParticipantStore:
    """Store for :class:`Participant` rows."""

    def __init__(self, session: Session) -> None:
        # Every participant row is deleted when the store is created.
        session.execute(delete(Participant))

    @staticmethod
    def get_all(
        session, room_pk: int, user_included: bool = True
    ) -> Iterator[Participant]:
        """Yield the participants of room *room_pk*.

        When *user_included* is false, participants flagged ``is_user`` are
        left out.
        """
        criteria = [Participant.room_id == room_pk]
        if not user_included:
            criteria.append(~Participant.is_user)
        yield from session.scalars(select(Participant).where(*criteria)).unique()

487 

488 

class BobStore:
    """Storage for XEP-0231 Bits of Binary (BoB) payloads such as stickers.

    Payloads are written as files under ``config.HOME_DIR / "slidge_stickers"``
    and indexed in the database by their SHA-1, SHA-256 and SHA-512 digests.
    """

    # Algorithm name as it appears in a CID -> name of the Bob digest column.
    _ATTR_MAP = {
        "sha-1": "sha_1",
        "sha1": "sha_1",
        "sha-256": "sha_256",
        "sha256": "sha_256",
        "sha-512": "sha_512",
        "sha512": "sha_512",
    }

    # Bob digest column -> hashlib constructor used to compute it.
    _ALG_MAP = {
        "sha_1": hashlib.sha1,
        "sha_256": hashlib.sha256,
        "sha_512": hashlib.sha512,
    }

    def __init__(self) -> None:
        self.root_dir = config.HOME_DIR / "slidge_stickers"
        self.root_dir.mkdir(exist_ok=True)

    @staticmethod
    def __split_cid(cid: str) -> list[str]:
        """Split ``algo+digest@bob.xmpp.org`` into ``[algo, digest]``.

        Callers unpack the result into two names, so a CID without exactly
        one ``+`` surfaces as a ValueError there.
        """
        return cid.removesuffix("@bob.xmpp.org").split("+")

    def __get_condition(self, cid: str):
        """Return a filter expression matching the Bob row for *cid*.

        Returns ``None`` (after logging a warning) if the algorithm is unknown.

        :raises ValueError: if *cid* does not split into algorithm and digest.
        """
        alg_name, digest = self.__split_cid(cid)
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Unknown hash algorithm: %s", alg_name)
            return None
        return getattr(Bob, attr) == digest

    def get(self, session: Session, cid: str) -> Bob | None:
        """Return the Bob row for *cid*, or ``None`` if it is unknown,
        malformed or uses an unsupported algorithm."""
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot get Bob with CID: %s", cid)
            return None
        if condition is None:
            # Unsupported algorithm: don't run a query with a NULL filter.
            return None
        return session.query(Bob).filter(condition).scalar()

    def get_sticker(self, session: Session, cid: str) -> Sticker | None:
        """Return the stored payload for *cid* as a :class:`Sticker`, or
        ``None`` if it is not stored."""
        bob = self.get(session, cid)
        if bob is None:
            return None
        return Sticker(
            self.root_dir / bob.file_name,
            bob.content_type,
            {h: getattr(bob, h) for h in self._ALG_MAP},
        )

    def get_bob(
        self, session: Session, _jid, _node, _ifrom, cid: str
    ) -> BitsOfBinary | None:
        """Build a BitsOfBinary element for *cid* from the stored file, or
        return ``None`` if it is not stored."""
        stored = self.get(session, cid)
        if stored is None:
            return None
        bob = BitsOfBinary()
        bob["data"] = (self.root_dir / stored.file_name).read_bytes()
        if stored.content_type is not None:
            bob["type"] = stored.content_type
        bob["cid"] = cid
        return bob

    def del_bob(self, session: Session, _jid, _node, _ifrom, cid: str) -> None:
        """Delete the Bob row for *cid* and its file on disk.

        Malformed or unknown CIDs are logged and ignored.
        """
        try:
            condition = self.__get_condition(cid)
        except ValueError:
            log.warning("Cannot delete Bob with CID: %s", cid)
            return None
        if condition is None:
            return None
        file_name = session.scalar(
            delete(Bob).where(condition).returning(Bob.file_name)
        )
        if file_name is None:
            log.warning("No BoB with CID: %s", cid)
            return None
        try:
            (self.root_dir / file_name).unlink()
        except FileNotFoundError:
            # The row is gone already; a missing file leaves nothing to clean.
            log.warning("File of BoB with CID %s was already missing", cid)

    def set_bob(self, session: Session, _jid, _node, _ifrom, bob: BitsOfBinary) -> None:
        """Store the payload of *bob* on disk and index it in the database.

        Nothing is stored for a malformed CID, an unsupported algorithm or an
        already known CID.

        :raises ValueError: if the payload's digest does not match the CID.
        """
        cid = bob["cid"]
        try:
            alg_name, digest = self.__split_cid(cid)
        except ValueError:
            log.warning("Invalid CID provided: %s", cid)
            return
        attr = self._ATTR_MAP.get(alg_name)
        if attr is None:
            log.warning("Cannot set Bob: Unknown algorithm type: %s", alg_name)
            return
        existing = self.get(session, bob["cid"])
        if existing:
            log.debug("Bob already exists")
            return
        bytes_ = bob["data"]
        # Verify the payload before touching the disk: the data may come
        # from a remote entity and must not be written if it is bogus.
        hashes = {k: v(bytes_).hexdigest() for k, v in self._ALG_MAP.items()}
        if hashes[attr] != digest:
            raise ValueError("Provided CID does not match calculated hash")
        path = self.root_dir / uuid.uuid4().hex
        if bob["type"]:
            path = path.with_suffix(guess_extension(bob["type"]) or "")
        path.write_bytes(bytes_)
        row = Bob(file_name=path.name, content_type=bob["type"] or None, **hashes)
        session.add(row)

590 

591 

# Logger shared by every store in this module, named after the module.
log = logging.getLogger(name=__name__)