Coverage for slidge/db/meta.py: 82%
40 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1from __future__ import annotations
3import json
4from typing import Any, Union
6import sqlalchemy as sa
7from slixmpp import JID
8from sqlalchemy import Dialect
class JIDType(sa.TypeDecorator[JID]):
    """
    SQLAlchemy column type that persists :class:`slixmpp.JID` objects as text.

    On the way in a JID is stored as ``str(jid)``; on the way out the stored
    text is turned back into a ``JID``. ``None`` passes through unchanged in
    both directions.
    """

    impl = sa.types.TEXT
    # The type carries no per-instance configuration, so it is safe to let
    # SQLAlchemy cache statements that use it.
    cache_ok = True

    def process_bind_param(self, value: JID | None, dialect: sa.Dialect) -> str | None:
        # Python -> database
        return None if value is None else str(value)

    def process_result_value(
        self, value: str | None, dialect: sa.Dialect
    ) -> JID | None:
        # database -> Python
        return None if value is None else JID(value)
# Recursive type alias for the dictionaries handled by ``JSONEncodedDict``:
# string keys mapping to str, float, None, or a nested dict of the same shape.
# The nested reference is a string forward reference because
# ``JSONSerializable`` is only bound on the following line.
# ``JSONSerializable`` is also a key of ``Base.type_annotation_map``, which is
# how mapped columns annotated with it get the ``JSONEncodedDict`` column type.
# NOTE(review): lists are not part of this alias even though ``json.loads``
# can return them — confirm the restriction is intentional.
JSONSerializableTypes = Union[str, float, None, "JSONSerializable"]
JSONSerializable = dict[str, JSONSerializableTypes]
class JSONEncodedDict(sa.TypeDecorator[JSONSerializable]):
    """
    SQLAlchemy column type that stores a JSON-serializable dict as JSON text.

    SQLAlchemy does not detect in-place mutations of the loaded dictionary;
    this is why ``UserStore.update()`` calls ``attributes.flag_modified()``.
    """

    impl = sa.VARCHAR
    # No per-instance configuration, so statement caching is safe.
    cache_ok = True

    def process_bind_param(
        self, value: JSONSerializable | None, dialect: Dialect
    ) -> str | None:
        # Python -> database: serialize with json.dumps' default settings.
        if value is not None:
            return json.dumps(value)
        return None

    def process_result_value(
        self, value: Any | None, dialect: Dialect
    ) -> JSONSerializable | None:
        # database -> Python: parse the stored JSON text back into a dict.
        if value is not None:
            return json.loads(value)  # type:ignore
        return None
class Base(sa.orm.DeclarativeBase):
    """
    Declarative base class for the ORM models built on this module's types.

    ``type_annotation_map`` tells SQLAlchemy which column type to use for
    mapped attributes annotated with ``JSONSerializable`` or ``JID``.
    """

    type_annotation_map = {
        JSONSerializable: JSONEncodedDict,
        JID: JIDType,
    }
# Templates SQLAlchemy uses to name indexes (ix), unique (uq), check (ck),
# foreign-key (fk) and primary-key (pk) constraints on tables of this
# metadata. Presumably this keeps names deterministic for schema migrations.
# NOTE(review): the "ck" template wraps the constraint name in literal
# backticks, unlike SQLAlchemy's documented convention. It is kept as-is
# because changing it would rename existing check constraints; confirm it is
# intentional.
Base.metadata.naming_convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_`%(constraint_name)s`",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)
def get_engine(path: str, echo: bool = False) -> sa.Engine:
    """
    Create an SQLAlchemy engine for ``path``.

    :param path: database URL, passed unchanged to ``sqlalchemy.create_engine``.
    :param echo: if true, turn on SQL statement echoing for the engine.
    :return: the newly created engine.

    Side effect when ``echo`` is true: SQLAlchemy's private
    ``sqlalchemy.log._add_default_handler`` is replaced by a no-op for the
    whole process (not just this engine). Presumably this is so SQLAlchemy
    does not attach its own log handler and echoed SQL goes through the
    application's logging setup instead.
    """
    from sqlalchemy import log as sqlalchemy_log

    engine = sa.create_engine(path)
    if not echo:
        return engine

    # The patch is applied before echo is switched on; presumably enabling
    # echo is what would otherwise install SQLAlchemy's default handler.
    sqlalchemy_log._add_default_handler = lambda logger: None  # type:ignore
    engine.echo = True
    return engine