Coverage for slidge / util / util.py: 80%
173 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 04:38 +0000
1import logging
2import mimetypes
3import re
4from collections.abc import Callable, Collection, Coroutine
5from functools import wraps
6from pathlib import Path
7from time import time
8from typing import (
9 Any,
10 ClassVar,
11 Concatenate,
12 NamedTuple,
13 ParamSpec,
14 Protocol,
15 TypeVar,
16)
18try:
19 import emoji
20except ImportError:
21 EMOJI_LIB_AVAILABLE = False
22else:
23 EMOJI_LIB_AVAILABLE = True
25from slixmpp.types import ExtPresenceShows, ResourceDict
27from .types import Mention
29try:
30 import magic
31except ImportError as e:
32 magic = None # type:ignore
33 logging.warning(
34 (
35 "Libmagic is not available: %s. "
36 "It's OK if you don't use fix-filename-suffix-mime-type."
37 ),
38 e,
39 )
def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the file's sniffed content type.

    The MIME type is sniffed from the file at *path* with libmagic and always
    wins over *mime_type*, which is only compared against it for logging.

    :param path: file whose content is inspected
    :param mime_type: MIME type announced for the file, if any
    :param file_name: announced file name; ``path.name`` is used when empty
    :return: ``(name, sniffed_mime_type)``. *name* keeps its suffix if it is
        one :mod:`mimetypes` lists for the sniffed type, gets the suffix
        :func:`mimetypes.guess_extension` proposes otherwise, and is left
        unchanged when no suffix is known for that type.

    If the optional ``magic`` import failed, ``magic`` is ``None`` and this
    raises :class:`AttributeError`.
    """
    sniffed = magic.from_file(path, mime=True)
    if sniffed != mime_type:
        log.debug("Magic (%s) and given MIME (%s) differ", sniffed, mime_type)
    else:
        log.debug("Magic and given MIME match")

    name = Path(file_name or path.name)
    current_suffix = name.suffix

    known_suffixes = mimetypes.guess_all_extensions(sniffed, strict=False)
    if current_suffix in known_suffixes:
        log.debug("Suffix %s is in %s", current_suffix, known_suffixes)
        return str(name), sniffed

    # Ignore any ";param=value" part when asking for a replacement suffix.
    replacement = mimetypes.guess_extension(sniffed.partition(";")[0], strict=False)
    if replacement is None:
        log.debug("No valid suffix found")
        return str(name), sniffed

    log.debug("Changing suffix of %s to %s", file_name or path.name, replacement)
    return str(name.with_suffix(replacement)), sniffed
71class SubclassableOnce:
72 # To allow importing everything, including plugins, during tests
73 TEST_MODE: bool = False
74 __subclasses: ClassVar[
75 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None]
76 ] = {}
78 def __init_subclass__(cls, **kwargs: object) -> None:
79 if SubclassableOnce not in cls.__bases__:
80 base = SubclassableOnce.__find_direct_child(cls)
81 existing = SubclassableOnce.__subclasses.get(base)
82 if existing is not None and not SubclassableOnce.TEST_MODE:
83 raise RuntimeError("This class must be subclassed once at most!")
84 cls.__subclasses[base] = cls
85 super().__init_subclass__(**kwargs)
87 @staticmethod
88 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]:
89 for base in cls.__bases__:
90 if issubclass(base, SubclassableOnce):
91 return base
92 else:
93 raise RuntimeError("wut")
95 @classmethod
96 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]":
97 try:
98 return cls.get_unique_subclass()
99 except AttributeError:
100 return cls
102 @classmethod
103 def get_unique_subclass(cls) -> "type[SubclassableOnce]":
104 existing = SubclassableOnce.__subclasses.get(cls)
105 if existing is None:
106 raise AttributeError("Could not find any subclass", cls)
107 return existing
109 @classmethod
110 def reset_subclass(cls) -> None:
111 log.debug("Resetting subclass of %s", cls)
112 cls.__subclasses[cls] = None
115def is_valid_phone_number(phone: str | None) -> bool:
116 if phone is None:
117 return False
118 match = re.match(r"\+\d.*", phone)
119 if match is None:
120 return False
121 return match[0] == phone
def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` with *repl*.

    :param s: text to clean
    :param repl: literal replacement for each offending character. Backslashes
        in it are not interpreted as regex escapes or group references.
    :return: the cleaned text
    """
    # re.sub processes backslash escapes in a string template; escape them so
    # that *repl* is always inserted literally (e.g. repl="\\" used to crash).
    return ILLEGAL_XML_CHARS_RE.sub(repl.replace("\\", "\\\\"), s)


# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges to strip.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # Lone surrogates: never valid XML characters, and a str containing one
    # cannot be encoded to UTF-8.
    (0xD800, 0xDFFF),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

ILLEGAL_RANGES = [rf"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)
160# from https://stackoverflow.com/a/35804945/5902284
161def addLoggingLevel(
162 levelName: str = "TRACE",
163 levelNum: int = logging.DEBUG - 5,
164 methodName: str | None = None,
165) -> None:
166 """
167 Comprehensively adds a new logging level to the `logging` module and the
168 currently configured logging class.
170 `levelName` becomes an attribute of the `logging` module with the value
171 `levelNum`. `methodName` becomes a convenience method for both `logging`
172 itself and the class returned by `logging.getLoggerClass()` (usually just
173 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
174 used.
176 To avoid accidental clobberings of existing attributes, this method will
177 raise an `AttributeError` if the level name is already an attribute of the
178 `logging` module or if the method name is already present
180 Example
181 -------
182 >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
183 >>> logging.getLogger(__name__).setLevel("TRACE")
184 >>> logging.getLogger(__name__).trace('that worked')
185 >>> logging.trace('so did this')
186 >>> logging.TRACE
187 5
189 """
190 if not methodName:
191 methodName = levelName.lower()
193 if hasattr(logging, levelName):
194 log.debug(f"{levelName} already defined in logging module")
195 return
196 if hasattr(logging, methodName):
197 log.debug(f"{methodName} already defined in logging module")
198 return
199 if hasattr(logging.getLoggerClass(), methodName):
200 log.debug(f"{methodName} already defined in logger class")
201 return
203 # This method was inspired by the answers to Stack Overflow post
204 # http://stackoverflow.com/q/2183233/2988730, especially
205 # http://stackoverflow.com/a/13638084/2988730
206 def logForLevel(self, message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
207 if self.isEnabledFor(levelNum):
208 self._log(levelNum, message, args, **kwargs)
210 def logToRoot(message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
211 logging.log(levelNum, message, *args, **kwargs)
213 logging.addLevelName(levelNum, levelName)
214 setattr(logging, levelName, levelNum)
215 setattr(logging.getLoggerClass(), methodName, logForLevel)
216 setattr(logging, methodName, logToRoot)
class SlidgeLogger(logging.Logger):
    """
    :class:`logging.Logger` subclass that declares a ``trace`` method.

    NOTE(review): presumably a typing aid for loggers that get a ``trace``
    method installed at runtime by addLoggingLevel() — confirm against callers.
    """

    def trace(self, *args: Any, **kwargs: Any) -> None:
        # No-op. The signature accepts the same arguments as the ``logForLevel``
        # method addLoggingLevel() installs (message, *args, **kwargs), while
        # still allowing the zero-argument call the previous stub accepted.
        pass
# module-level logger, named after this module
log: logging.Logger = logging.getLogger(name=__name__)
def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Collapse the presences of several resources into a single one.

    :param resources: presence dicts keyed by resource
    :return: ``None`` for no resources; the single entry itself (not a copy)
        when there is exactly one; otherwise a new dict with ``priority`` 0,
        ``show`` set to ``""`` if any resource has ``show == ""`` (else the
        first non-empty ``show`` in descending priority), and ``status`` set
        to the first non-empty status in descending priority (``""`` if none).
    :raises RuntimeError: if no resource has ``show == ""`` and none has a
        non-empty ``show``.
    """
    if not resources:
        return None

    if len(resources) == 1:
        (only,) = resources.values()
        return only

    ranked = sorted(resources.values(), key=lambda res: res["priority"], reverse=True)

    show: ExtPresenceShows
    if any(res["show"] == "" for res in resources.values()):
        # if a client is "available", we're "available"
        show = ""
    else:
        first_show = next((res["show"] for res in ranked if res["show"]), None)
        if first_show is None:
            raise RuntimeError()
        show = first_show

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {"show": show, "status": status, "priority": 0}
def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Return *emoji* with every U+FE0F (VARIATION SELECTOR-16) removed.

    Works directly on code points instead of round-tripping through UTF-8, so
    it no longer raises on strings containing lone surrogates.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")
NamedTupleT = TypeVar("NamedTupleT", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[NamedTupleT]) -> NamedTupleT:
    """
    Build a *cls* instance from the matching keys of *data*.

    Keys of *data* that are not fields of *cls* are ignored. A field missing
    from *data* takes the default declared on *cls* if it has one, and
    ``None`` otherwise. An explicit ``None`` value in *data* is kept as is.
    """
    defaults = cls._field_defaults
    return cls(*(data.get(f, defaults.get(f)) for f in cls._fields))  # type:ignore[arg-type]
def replace_mentions(
    text: str,
    mentions: Collection[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Replace each mention's ``text[start:end]`` span with ``mapping(mention)``.

    Mentions are processed in ascending ``start`` order. *mentions* may be any
    collection (possibly unordered, e.g. a set), and the cursor-based splicing
    below only works when spans are visited left to right. Sorting is stable,
    so an already-ordered sequence gives the same result as before.

    If ``mapping(mention)`` raises, ``mapping(mention.contact)`` is tried
    instead (slidge <= 0.3.3 compatibility, per the log message).

    :return: *text* unchanged when *mentions* is empty or ``None``
    """
    if not mentions:
        return text

    cursor = 0
    pieces: list[str] = []
    for mention in sorted(mentions, key=lambda m: m.start):
        try:
            new_text = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            new_text = mapping(mention.contact)  # type:ignore
        pieces.extend([text[cursor : mention.start], new_text])
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)
296class HasLogger(Protocol):
297 log: logging.Logger
300P = ParamSpec("P")
301T = TypeVar("T")
302Self = TypeVar("Self", bound=HasLogger)
303TimeItWrapped = Callable[Concatenate[Self, P], Coroutine[Any, Any, T]]
306def timeit(func: TimeItWrapped[Self, P, T]) -> TimeItWrapped[Self, P, T]:
307 @wraps(func)
308 async def wrapped(self: Self, /, *args: P.args, **kwargs: P.kwargs) -> T:
309 start = time()
310 r = await func(self, *args, **kwargs)
311 self.log.debug("%s took %s ms", func.__name__, round((time() - start) * 1000))
312 return r
314 return wrapped
def strip_leading_emoji(text: str) -> str:
    """
    Drop a leading emoji-only word (and the space after it) from *text*.

    *text* is returned unchanged when the ``emoji`` library is unavailable,
    when it contains no space, or when its first space-separated word is not
    made purely of emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    head, separator, tail = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if separator and emoji.purely_emoji(head):
        return tail
    return text
async def noop_coro() -> None:
    """Coroutine that does nothing; awaiting it yields ``None``."""
    return None
def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line becomes ``"> " + line`` with surrounding whitespace stripped (so
    an empty line becomes a bare ``">"``), and the joined result is stripped.
    """
    quoted_lines = [f"> {line}".strip() for line in text.split("\n")]
    return "\n".join(quoted_lines).strip()