Coverage for slidge / core / mixins / db.py: 68%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-06 15:18 +0000

1import logging 

2import typing 

3from contextlib import contextmanager 

4 

5import sqlalchemy as sa 

6from sqlalchemy.exc import IntegrityError 

7 

8from ...db.models import Base, Contact, Participant, Room 

9 

10if typing.TYPE_CHECKING: 

11 from slidge import BaseGateway 

12 

13 

class DBMixin:
    """
    Tie an object to a persistent ORM row held in ``self.stored``.

    Classes using this mixin must provide ``stored``, ``xmpp`` (whose
    ``store.session(...)`` is used as a context manager yielding an ORM
    session) and ``log``.
    """

    stored: Base
    xmpp: "BaseGateway"
    log: logging.Logger

    def merge(self) -> None:
        """
        Run ``self.stored`` through the session's ``merge()`` and keep the result.

        This method does not call ``commit()`` on the session itself.
        """
        with self.xmpp.store.session() as db_session:
            merged_row = db_session.merge(self.stored)
            self.stored = merged_row

    def commit(self, merge: bool = False) -> None:
        """
        Add ``self.stored`` to a fresh session and commit it.

        :param merge: when true, ``self.stored`` is first replaced by the
            instance returned from the session's ``merge()``, and that merged
            instance is what gets added and committed.
        """
        # expire_on_commit=False (presumably forwarded to the SQLAlchemy
        # Session) keeps the row's attributes from being expired on commit,
        # so they stay readable after the session is closed.
        open_session = self.xmpp.store.session
        with open_session(expire_on_commit=False) as db_session:
            if merge:
                self.log.debug("Merging %s", self.stored)
                merged_row = db_session.merge(self.stored)
                self.stored = merged_row
                self.log.debug("Merged %s", self.stored)
            db_session.add(self.stored)
            self.log.debug("Committing to DB")
            db_session.commit()

32 

33 

class UpdateInfoMixin(DBMixin):
    """
    Batch DB writes on top of :class:`DBMixin`.

    Inside the :meth:`updating_info` context manager, :meth:`commit` and
    :meth:`update_stored_attribute` do not write to the DB. This prevents
    committing to the DB on every attribute change; a single commit is issued
    when the block exits normally.
    """

    stored: Contact | Room
    xmpp: "BaseGateway"
    log: logging.Logger

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # True while inside updating_info(): DB writes are deferred.
        self._updating_info = False
        # Restore custom attributes previously persisted on the row, if any.
        self.__deserialize()

    def __deserialize(self) -> None:
        # Hand the stored extra-attributes payload back to the subclass hook.
        if self.stored.extra_attributes is not None:
            self.deserialize_extra_attributes(self.stored.extra_attributes)

    def refresh(self) -> None:
        """
        Reload ``self.stored`` from the DB and re-apply its extra attributes.
        """
        with self.xmpp.store.session(expire_on_commit=False) as orm:
            orm.add(self.stored)
            orm.refresh(self.stored)
            self.__deserialize()

    def serialize_extra_attributes(self) -> dict | None:
        """
        If you want custom attributes of your instance to be stored persistently
        to the DB, here is where you have to return them as a dict to be used in
        `deserialize_extra_attributes()`.

        The returned value is written to ``stored.extra_attributes`` on every
        non-deferred :meth:`commit`. The default implementation returns None.
        """
        return None

    def deserialize_extra_attributes(self, data: dict) -> None:
        """
        This is where you get the dict that you passed in
        `serialize_extra_attributes()`.

        ⚠ Since it is serialized as json, dictionary keys are converted to strings!
        Be sure to convert to other types if necessary.
        """
        pass

    @contextmanager
    def updating_info(self, merge: bool = True) -> typing.Iterator[None]:
        """
        Defer DB writes for the duration of the ``with`` block.

        While the block runs, :meth:`commit` and :meth:`update_stored_attribute`
        only modify the in-memory ``self.stored``. When the block exits
        normally, ``stored.updated`` is set to True and a single :meth:`commit`
        is issued. If the block raises, the deferral flag is still cleared (so
        later writes are not silently skipped), nothing is committed, and the
        exception propagates.

        :param merge: forwarded to :meth:`commit`.
        :raises IntegrityError: from the final commit, unless ``self`` is a
            ``LegacyMUC`` whose ``_ALL_INFO_FILLED_ON_STARTUP`` is truthy, in
            which case a reconciliation commit is attempted instead.
        """
        self._updating_info = True
        try:
            yield
        finally:
            # Reset even on error; a stuck flag would make every later
            # commit() a no-op for this instance.
            self._updating_info = False
        self.stored.updated = True
        try:
            self.commit(merge=merge)
        except IntegrityError:
            from slidge.group import LegacyMUC

            if not isinstance(self, LegacyMUC):
                raise
            if not self._ALL_INFO_FILLED_ON_STARTUP:
                raise
            # NOTE(review): presumably the room row already exists in the DB;
            # reconcile with it and commit again.
            with self.xmpp.store.session(expire_on_commit=False) as orm:
                if self.stored.id is None:
                    # Adopt the primary key of the room row found in the store.
                    self.stored.id = self.xmpp.store.rooms.get(
                        orm, self.user_pk, legacy_id=str(self.legacy_id)
                    ).id
                merged = orm.merge(self.stored)
                # Keep the first participant for each resource and drop any
                # later duplicates.
                seen_resources = set()
                participants: list[Participant] = []
                for participant in merged.participants:
                    if participant.resource in seen_resources:
                        self.session.log.warning("ditching: %s", participant)
                        continue
                    seen_resources.add(participant.resource)
                    participants.append(participant)
                merged.participants = participants
                self.stored = merged
                orm.add(self.stored)
                orm.commit()

    def commit(self, merge: bool = False) -> None:
        """
        Serialize extra attributes onto ``self.stored`` and commit it.

        Inside :meth:`updating_info` this only logs and returns; the commit
        happens when that block exits.

        :param merge: forwarded to :meth:`DBMixin.commit`.
        """
        if self._updating_info:
            self.log.debug("Not updating %s right now", self.stored)
        else:
            self.stored.extra_attributes = self.serialize_extra_attributes()
            super().commit(merge=merge)

    def update_stored_attribute(self, **kwargs) -> None:
        """
        Set each keyword argument as an attribute of ``self.stored``.

        Outside :meth:`updating_info`, also issue an SQL UPDATE of those
        columns for the row whose ``id`` equals ``self.stored.id``, and commit.
        """
        for key, value in kwargs.items():
            setattr(self.stored, key, value)
        if self._updating_info:
            return
        with self.xmpp.store.session() as orm:
            orm.execute(
                sa.update(self.stored.__class__)
                .where(self.stored.__class__.id == self.stored.id)
                .values(**kwargs)
            )
            orm.commit()