Coverage for slidge / util / util.py: 78%
176 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1import logging
2import mimetypes
3import re
4import warnings
5from collections.abc import Awaitable, Callable
6from functools import wraps
7from pathlib import Path
8from time import time
9from typing import Any, ClassVar, NamedTuple, TypeVar
11try:
12 import emoji
13except ImportError:
14 EMOJI_LIB_AVAILABLE = False
15else:
16 EMOJI_LIB_AVAILABLE = True
18from slixmpp.types import ResourceDict
20from .types import Mention
22try:
23 import magic
24except ImportError as e:
25 magic = None # type:ignore
26 logging.warning(
27 (
28 "Libmagic is not available: %s. "
29 "It's OK if you don't use fix-filename-suffix-mime-type."
30 ),
31 e,
32 )
def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the MIME type detected by libmagic.

    :param path: Path of the file whose content is inspected. Its name is
        used as the file name when ``file_name`` is falsy.
    :param mime_type: MIME type announced for the file. It is only compared
        to the detected one for logging purposes; the detected type is the
        one that is used and returned.
    :param file_name: Preferred file name, if any.
    :return: A ``(file name, detected MIME type)`` pair. The file name keeps
        its suffix when that suffix is known for the detected type or when
        no suffix can be guessed; otherwise the suffix is replaced.
    """
    # NOTE: ``magic`` is None when the libmagic import failed at module load,
    # in which case this call raises AttributeError.
    detected = magic.from_file(path, mime=True)
    if detected != mime_type:
        log.debug("Magic (%s) and given MIME (%s) differ", detected, mime_type)
        mime_type = detected
    else:
        log.debug("Magic and given MIME match")

    accepted_suffixes = mimetypes.guess_all_extensions(mime_type, strict=False)

    candidate = Path(file_name) if file_name else Path(path.name)
    current_suffix = candidate.suffix

    if current_suffix in accepted_suffixes:
        log.debug("Suffix %s is in %s", current_suffix, accepted_suffixes)
        return str(candidate), detected

    # Drop any ";parameter" part before asking for the preferred extension.
    bare_type = mime_type.split(";")[0]
    replacement = mimetypes.guess_extension(bare_type, strict=False)
    if replacement is not None:
        log.debug("Changing suffix of %s to %s", file_name or path.name, replacement)
        return str(candidate.with_suffix(replacement)), detected

    log.debug("No valid suffix found")
    return str(candidate), detected
67class SubclassableOnce:
68 # To allow importing everything, including plugins, during tests
69 TEST_MODE: bool = False
70 __subclasses: ClassVar[
71 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None]
72 ] = {}
74 def __init_subclass__(cls, **kwargs: object) -> None:
75 if SubclassableOnce not in cls.__bases__:
76 base = SubclassableOnce.__find_direct_child(cls)
77 existing = SubclassableOnce.__subclasses.get(base)
78 if existing is not None and not SubclassableOnce.TEST_MODE:
79 raise RuntimeError("This class must be subclassed once at most!")
80 cls.__subclasses[base] = cls
81 super().__init_subclass__(**kwargs)
83 @staticmethod
84 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]:
85 for base in cls.__bases__:
86 if issubclass(base, SubclassableOnce):
87 return base
88 else:
89 raise RuntimeError("wut")
91 @classmethod
92 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]":
93 try:
94 return cls.get_unique_subclass()
95 except AttributeError:
96 return cls
98 @classmethod
99 def get_unique_subclass(cls) -> "type[SubclassableOnce]":
100 existing = SubclassableOnce.__subclasses.get(cls)
101 if existing is None:
102 raise AttributeError("Could not find any subclass", cls)
103 return existing
105 @classmethod
106 def reset_subclass(cls) -> None:
107 log.debug("Resetting subclass of %s", cls)
108 cls.__subclasses[cls] = None
111def is_valid_phone_number(phone: str | None) -> bool:
112 if phone is None:
113 return False
114 match = re.match(r"\+\d.*", phone)
115 if match is None:
116 return False
117 return match[0] == phone
def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: Text to clean up.
    :param repl: Replacement for each matched character (removed by default).
    :return: The cleaned-up text.
    """
    return re.sub(ILLEGAL_XML_CHARS_RE, repl, s)
# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (first, last) code point ranges matched by ILLEGAL_XML_CHARS_RE.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    # The last two code points of each plane from 0x1 to 0x10.
    *((plane << 16 | 0xFFFE, plane << 16 | 0xFFFF) for plane in range(0x1, 0x11)),
]

# One "first-last" character-class range per entry of ILLEGAL.
ILLEGAL_RANGES = [chr(first) + "-" + chr(last) for first, last in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = f"[{''.join(ILLEGAL_RANGES)}]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)
156# from https://stackoverflow.com/a/35804945/5902284
157def addLoggingLevel(
158 levelName: str = "TRACE",
159 levelNum: int = logging.DEBUG - 5,
160 methodName: str | None = None,
161) -> None:
162 """
163 Comprehensively adds a new logging level to the `logging` module and the
164 currently configured logging class.
166 `levelName` becomes an attribute of the `logging` module with the value
167 `levelNum`. `methodName` becomes a convenience method for both `logging`
168 itself and the class returned by `logging.getLoggerClass()` (usually just
169 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
170 used.
172 To avoid accidental clobberings of existing attributes, this method will
173 raise an `AttributeError` if the level name is already an attribute of the
174 `logging` module or if the method name is already present
176 Example
177 -------
178 >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
179 >>> logging.getLogger(__name__).setLevel("TRACE")
180 >>> logging.getLogger(__name__).trace('that worked')
181 >>> logging.trace('so did this')
182 >>> logging.TRACE
183 5
185 """
186 if not methodName:
187 methodName = levelName.lower()
189 if hasattr(logging, levelName):
190 log.debug(f"{levelName} already defined in logging module")
191 return
192 if hasattr(logging, methodName):
193 log.debug(f"{methodName} already defined in logging module")
194 return
195 if hasattr(logging.getLoggerClass(), methodName):
196 log.debug(f"{methodName} already defined in logger class")
197 return
199 # This method was inspired by the answers to Stack Overflow post
200 # http://stackoverflow.com/q/2183233/2988730, especially
201 # http://stackoverflow.com/a/13638084/2988730
202 def logForLevel(self, message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
203 if self.isEnabledFor(levelNum):
204 self._log(levelNum, message, args, **kwargs)
206 def logToRoot(message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
207 logging.log(levelNum, message, *args, **kwargs)
209 logging.addLevelName(levelNum, levelName)
210 setattr(logging, levelName, levelNum)
211 setattr(logging.getLoggerClass(), methodName, logForLevel)
212 setattr(logging, methodName, logToRoot)
class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass declaring a ``trace()`` method.
    """

    def trace(self) -> None:
        """
        Do nothing.

        NOTE(review): presumably a placeholder for the ``trace`` method that
        ``addLoggingLevel`` installs at runtime; confirm with callers.
        """
        return None
# Module-level logger used by the helpers of this module.
log = logging.getLogger(__name__)
def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Combine the presence of several resources into a single one.

    :param resources: Presence information of each resource.
    :return: ``None`` if there are no resources; the only resource if there
        is exactly one; otherwise a new dict with a merged ``show`` and
        ``status`` and a ``priority`` of 0.
    :raises RuntimeError: if no resource has ``show == ""`` and none has a
        truthy ``show`` either.
    """
    count = len(resources)
    if count == 0:
        return None
    if count == 1:
        (only,) = resources.values()
        return only

    ranked = sorted(resources.values(), key=lambda r: r["priority"], reverse=True)

    if any(r["show"] == "" for r in resources.values()):
        # if a client is "available", we're "available"
        show = ""
    else:
        # otherwise, use the "show" of the highest priority resource having one
        show = next((r["show"] for r in ranked if r["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((r["status"] for r in ranked if r["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }
def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Remove every VARIATION SELECTOR-16 (U+FE0F) character from ``emoji``.

    The removal is done on the string itself rather than on its UTF-8
    encoding, so input that cannot be encoded (e.g. a lone surrogate) no
    longer raises ``UnicodeEncodeError``.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")
def deprecated(name: str, new: Callable):  # type:ignore[no-untyped-def,type-arg] # noqa
    """
    Build a stand-in for a deprecated callable.

    :param name: Name of the deprecated callable, used in the warning text.
    :param new: Replacement callable; every call is forwarded to it.
    :return: A function that emits a ``DeprecationWarning`` and then returns
        ``new(*args, **kwargs)``.
    """

    # NOTE: functools.wraps is deliberately not applied, as in the original
    # code where it was commented out; the wrapper keeps its own metadata.
    def wrapped(*args, **kwargs):  # type:ignore[no-untyped-def] # noqa
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the caller of the deprecated name,
            # not to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[T]) -> T:
    """
    Instantiate the named tuple class ``cls`` from the entries of ``data``.

    Fields of ``cls`` that are missing from ``data`` are set to ``None``;
    keys of ``data`` that are not fields of ``cls`` are ignored.
    """
    return cls(*map(data.get, cls._fields))  # type:ignore
def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Substitute each mention of ``text`` with the string ``mapping`` returns.

    :param text: Original text.
    :param mentions: Mentions to replace, each with ``start`` and ``end``
        offsets in ``text``. They are processed in list order
        (presumably sorted and non-overlapping; verify against callers).
    :param mapping: Called with each mention to get its replacement. If it
        raises, it is called again with ``mention.contact`` instead.
    :return: The text with the mentions replaced, or ``text`` unchanged if
        ``mentions`` is empty or ``None``.
    """
    if not mentions:
        return text

    chunks = []
    position = 0
    for mention in mentions:
        try:
            replacement = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            replacement = mapping(mention.contact)  # type:ignore
        # text between the previous mention and this one, then the replacement
        chunks.append(text[position : mention.start])
        chunks.append(replacement)
        position = mention.end
    # whatever follows the last mention
    chunks.append(text[position:])
    return "".join(chunks)
TimeItWrapped = TypeVar("TimeItWrapped")


def timeit(
    func: Callable[..., Awaitable[TimeItWrapped]],
) -> Callable[..., Awaitable[TimeItWrapped]]:
    """
    Decorate an async method so that its duration is logged at debug level.

    The decorated method's first argument must have a ``log`` attribute with
    a ``debug()`` method; the duration is reported in rounded milliseconds.
    """

    @wraps(func)
    async def wrapped(self: object, *args: object, **kwargs: object) -> TimeItWrapped:
        began = time()
        result = await func(self, *args, **kwargs)
        elapsed_ms = round((time() - began) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)  # type:ignore
        return result

    return wrapped
def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it is made of emojis.

    The text is returned unchanged if the emoji library is not available, if
    it contains no space, or if its first word is not purely emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    first_word, space, remainder = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if space and emoji.purely_emoji(first_word):
        return remainder
    return text
async def noop_coro() -> None:
    """
    Coroutine that does nothing and returns ``None``.
    """
    return None
def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line is prefixed with "> " and then stripped of surrounding
    whitespace, so an empty line becomes a bare ">".
    """
    quoted_lines = []
    for line in text.split("\n"):
        quoted_lines.append(f"> {line}".strip())
    return "\n".join(quoted_lines).strip()