Coverage for slidge/db/meta.py: 82%

40 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1from __future__ import annotations 

2 

3import json 

4from typing import Any, Union 

5 

6import sqlalchemy as sa 

7from slixmpp import JID 

8from sqlalchemy import Dialect 

9 

10 

class JIDType(sa.TypeDecorator[JID]):
    """
    SQLAlchemy column type that persists a :class:`slixmpp.JID` as TEXT.

    Values are written with ``str(jid)`` and read back with ``JID(text)``;
    SQL ``NULL`` maps to ``None`` in both directions.
    """

    # Declares this type safe for SQLAlchemy's compiled-statement cache.
    cache_ok = True
    impl = sa.types.TEXT

    def process_bind_param(
        self, value: JID | None, dialect: sa.Dialect
    ) -> str | None:
        # Python -> database: serialise the JID to its string form.
        return None if value is None else str(value)

    def process_result_value(
        self, value: str | None, dialect: sa.Dialect
    ) -> JID | None:
        # Database -> Python: parse the stored string back into a JID.
        return None if value is None else JID(value)

30 

31 

# Values allowed inside a JSONSerializable dict: strings, numbers, None, or a
# nested JSONSerializable dict (recursive, hence the string forward reference).
# ``Union[...]`` is used instead of ``X | Y`` because these aliases are
# evaluated at runtime, where ``str | "JSONSerializable"`` raises TypeError.
# NOTE(review): JSON arrays are not part of this alias although json.loads can
# produce them — confirm whether stored dicts ever contain lists.
JSONSerializableTypes = Union[
    str,
    float,
    None,
    "JSONSerializable",
]
# Also the key in ``Base.type_annotation_map`` that maps columns annotated
# with this alias to JSONEncodedDict, so the expression must stay as-is.
JSONSerializable = dict[str, JSONSerializableTypes]

34 

35 

class JSONEncodedDict(sa.TypeDecorator[JSONSerializable]):
    """
    SQLAlchemy column type that stores a dictionary as a JSON string.

    SQLAlchemy does not detect in-place mutations of the dictionary, which is
    why ``attributes.flag_modified()`` is used in ``UserStore.update()``.
    """

    # Declares this type safe for SQLAlchemy's compiled-statement cache.
    cache_ok = True
    impl = sa.VARCHAR

    def process_bind_param(
        self, value: JSONSerializable | None, dialect: Dialect
    ) -> str | None:
        # Python -> database: NULL stays NULL, anything else is JSON-encoded.
        if value is not None:
            return json.dumps(value)
        return None

    def process_result_value(
        self, value: Any | None, dialect: Dialect
    ) -> JSONSerializable | None:
        # Database -> Python: decode the stored JSON text; NULL maps to None.
        if value is not None:
            return json.loads(value)  # type:ignore
        return None

61 

62 

class Base(sa.orm.DeclarativeBase):
    """
    Declarative base class for the ORM models declared on it.

    ``type_annotation_map`` makes columns annotated as ``Mapped[JSONSerializable]``
    or ``Mapped[JID]`` use the custom column types defined in this module.
    """

    type_annotation_map = {
        JSONSerializable: JSONEncodedDict,
        JID: JIDType,
    }

65 

66 

# Name templates SQLAlchemy applies to indexes and constraints (ix, uq, ck,
# fk, pk) that are declared without an explicit name.
# NOTE(review): the "ck" template wraps the constraint name in literal
# backticks, which become part of the generated names. Left unchanged because
# existing databases and migrations may already rely on those exact names.
Base.metadata.naming_convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_`%(constraint_name)s`",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

74 

75 

def get_engine(path: str, echo: bool = False) -> sa.Engine:
    """
    Create an SQLAlchemy engine for the database URL ``path``.

    :param path: database URL, passed verbatim to ``sa.create_engine``.
    :param echo: when true, enable SQL statement logging on the engine.
    :return: the newly created engine.
    """
    engine = sa.create_engine(path)
    if not echo:
        return engine

    from sqlalchemy import log as sqlalchemy_log

    # No-op SQLAlchemy's private ``_add_default_handler`` hook (presumably the
    # one that attaches its own default stream handler when echo is turned on)
    # so statement logs flow only through the application's logging setup.
    # NOTE: private API, patched module-wide and never restored — re-check on
    # SQLAlchemy upgrades.
    sqlalchemy_log._add_default_handler = lambda _logger: None  # type:ignore
    engine.echo = True
    return engine