Coverage for slidge / core / mixins / db.py: 68%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import logging 

2import typing 

3from contextlib import contextmanager 

4 

5import sqlalchemy as sa 

6from sqlalchemy.exc import IntegrityError 

7 

8from ...db.meta import Base, JSONSerializable 

9from ...db.models import Contact, Participant, Room 

10 

11if typing.TYPE_CHECKING: 

12 from slidge import BaseGateway 

13 

14 

class DBMixin:
    """
    Give an object a persistent ORM row (``stored``) together with helpers
    that write that row through sessions obtained from ``self.xmpp.store``.
    """

    # ORM instance backing this object.
    stored: Base
    # Gateway whose ``store.session()`` hands out DB sessions.
    xmpp: "BaseGateway"
    log: logging.Logger

    def merge(self) -> None:
        """
        Merge ``self.stored`` into a new session and keep the instance that
        ``merge()`` returns in its place. No commit is issued here.
        """
        with self.xmpp.store.session() as db:
            merged_row = db.merge(self.stored)
            self.stored = merged_row

    def commit(self, merge: bool = False) -> None:
        """
        Add ``self.stored`` to a fresh session and commit it.

        :param merge: if true, ``self.stored`` is first merged into the session
            and replaced by the merged instance (logged before and after).
        """
        # NOTE(review): expire_on_commit=False is presumably forwarded to the
        # SQLAlchemy session so ``self.stored`` keeps its loaded values after
        # the commit — confirm against ``store.session()``.
        with self.xmpp.store.session(expire_on_commit=False) as db:
            if merge:
                self.log.debug("Merging %s", self.stored)
                merged_row = db.merge(self.stored)
                self.stored = merged_row
                self.log.debug("Merged %s", self.stored)
            db.add(self.stored)
            self.log.debug("Committing to DB")
            db.commit()

33 

34 

class UpdateInfoMixin(DBMixin):
    """
    This mixin just adds a context manager that prevents committing to the DB
    on every attribute change.

    Inside :meth:`updating_info`, :meth:`commit` and
    :meth:`update_stored_attribute` only modify the in-memory ``stored``
    instance. A single commit is issued when the context manager exits.
    """

    stored: Contact | Room
    xmpp: "BaseGateway"
    log: logging.Logger

    def __init__(self, *args: object, **kwargs: object) -> None:
        super().__init__(*args, **kwargs)
        # While True, DB writes from commit()/update_stored_attribute() are
        # deferred until updating_info() exits.
        self._updating_info = False
        self.__deserialize()

    def __deserialize(self) -> None:
        # Hand persisted extra attributes (if any) to the subclass hook.
        if self.stored.extra_attributes is not None:
            self.deserialize_extra_attributes(self.stored.extra_attributes)

    def refresh(self) -> None:
        """
        Reload ``self.stored`` from the DB, then re-apply its extra attributes
        via :meth:`deserialize_extra_attributes`.
        """
        with self.xmpp.store.session(expire_on_commit=False) as orm:
            orm.add(self.stored)
            orm.refresh(self.stored)
            self.__deserialize()

    def serialize_extra_attributes(self) -> JSONSerializable | None:
        """
        If you want custom attributes of your instance to be stored persistently
        to the DB, here is where you have to return them as a dict to be used in
        `deserialize_extra_attributes()`.

        """
        return None

    def deserialize_extra_attributes(self, data: JSONSerializable) -> None:
        """
        This is where you get the dict that you passed in
        `serialize_extra_attributes()`.

        ⚠ Since it is serialized as json, dictionary keys are converted to strings!
        Be sure to convert to other types if necessary.
        """
        pass

    @contextmanager
    def updating_info(self, merge: bool = True) -> typing.Iterator[None]:
        """
        Batch several attribute changes into a single DB commit.

        While the ``with`` block runs, :meth:`commit` and
        :meth:`update_stored_attribute` do not write to the DB. On normal exit,
        ``stored.updated`` is set to True and the object is committed once.
        If the block raises, the deferral flag is reset, nothing is committed,
        and the exception propagates.

        :param merge: forwarded to :meth:`commit`. When true, ``self.stored``
            is merged into the session before committing.
        :raises IntegrityError: if the final commit fails. The exception is
            not raised for a ``LegacyMUC`` whose ``_ALL_INFO_FILLED_ON_STARTUP``
            is set; for those, a recovery commit is attempted that keeps only
            one participant per resource.
        """
        self._updating_info = True
        try:
            yield
        finally:
            # Always clear the flag. Leaving it set after an exception would
            # turn every later commit() on this object into a silent no-op.
            self._updating_info = False
        self.stored.updated = True
        try:
            self.commit(merge=merge)
        except IntegrityError:
            # Local import, presumably to avoid an import cycle — TODO confirm.
            from slidge.group import LegacyMUC

            if not isinstance(self, LegacyMUC):
                raise
            if not self._ALL_INFO_FILLED_ON_STARTUP:
                raise
            with self.xmpp.store.session(expire_on_commit=False) as orm:
                if self.stored.id is None:
                    # Look up the room row by (user_pk, legacy_id) and reuse
                    # its id so that merge() targets that existing row.
                    self.stored.id = self.xmpp.store.rooms.get(
                        orm, self.user_pk, legacy_id=str(self.legacy_id)
                    ).id
                merged = orm.merge(self.stored)
                # Keep only the first participant seen for each resource.
                # NOTE(review): presumably duplicate resources are what
                # violated the constraint — confirm.
                resources = set()
                participants: list[Participant] = []
                for participant in merged.participants:
                    if participant.resource in resources:
                        self.session.log.warning("ditching: %s", participant)
                        continue
                    resources.add(participant.resource)
                    participants.append(participant)
                merged.participants = participants
                self.stored = merged
                orm.add(self.stored)
                orm.commit()

    def commit(self, merge: bool = False) -> None:
        """
        Serialize extra attributes into ``stored`` and commit it.

        This is skipped (only logged) while inside :meth:`updating_info`.

        :param merge: forwarded to the base ``commit``.
        """
        if self._updating_info:
            self.log.debug("Not updating %s right now", self.stored)
        else:
            self.stored.extra_attributes = self.serialize_extra_attributes()
            super().commit(merge=merge)

    def update_stored_attribute(self, **kwargs: object) -> None:
        """
        Set attributes on ``self.stored``.

        Unless inside :meth:`updating_info`, also issue an UPDATE for the row
        whose ``id`` matches ``self.stored.id``.

        :param kwargs: attribute name/value pairs. These are also passed to
            ``.values()``, so they presumably must be column names of the
            stored model — TODO confirm.
        """
        for key, value in kwargs.items():
            setattr(self.stored, key, value)
        if self._updating_info:
            return
        with self.xmpp.store.session() as orm:
            orm.execute(
                sa.update(self.stored.__class__)
                .where(self.stored.__class__.id == self.stored.id)
                .values(**kwargs)
            )
            orm.commit()