Coverage for slidge / util / util.py: 79%
174 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
1import logging
2import mimetypes
3import re
4import warnings
5from collections.abc import Awaitable, Callable
6from functools import wraps
7from pathlib import Path
8from time import time
9from typing import Any, ClassVar, NamedTuple, TypeVar
11try:
12 import emoji
13except ImportError:
14 EMOJI_LIB_AVAILABLE = False
15else:
16 EMOJI_LIB_AVAILABLE = True
18from slixmpp.types import ResourceDict
20from .types import Mention
22try:
23 import magic
24except ImportError as e:
25 magic = None # type:ignore
26 logging.warning(
27 (
28 "Libmagic is not available: %s. "
29 "It's OK if you don't use fix-filename-suffix-mime-type."
30 ),
31 e,
32 )
def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix match the file's actual content.

    The MIME type is sniffed from the file itself with libmagic and always
    wins over the ``mime_type`` provided by the caller, which is only compared
    against it for logging purposes.

    :param path: Path of the local file to inspect.
    :param mime_type: MIME type announced for the file, if any.
    :param file_name: Announced file name. When empty or ``None``,
        ``path.name`` is used instead.
    :return: ``(name, sniffed_mime_type)``. ``name`` keeps its suffix when it
        is already one of the suffixes known for the sniffed type (compared
        case-insensitively), gets the canonical suffix for that type
        otherwise, and is left untouched when no suffix is known for it.
    """
    # NOTE: ``magic`` is None when libmagic failed to import at module load;
    # this call then raises AttributeError.
    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)

    # mimetypes looks up bare "type/subtype" strings: drop any parameters
    # (e.g. "; charset=binary") once, and use the result for both lookups.
    bare_type = guessed.split(";")[0].strip()
    valid_suffix_list = mimetypes.guess_all_extensions(bare_type, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    # Compare case-insensitively so that e.g. "photo.JPG" is accepted for
    # image/jpeg instead of being needlessly renamed.
    if suffix.lower() in valid_suffix_list:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return str(name), guessed

    valid_suffix = mimetypes.guess_extension(bare_type, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return str(name), guessed

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return str(name.with_suffix(valid_suffix)), guessed
64class SubclassableOnce:
65 # To allow importing everything, including plugins, during tests
66 TEST_MODE: bool = False
67 __subclasses: ClassVar[
68 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None]
69 ] = {}
71 def __init_subclass__(cls, **kwargs: object) -> None:
72 if SubclassableOnce not in cls.__bases__:
73 base = SubclassableOnce.__find_direct_child(cls)
74 existing = SubclassableOnce.__subclasses.get(base)
75 if existing is not None and not SubclassableOnce.TEST_MODE:
76 raise RuntimeError("This class must be subclassed once at most!")
77 cls.__subclasses[base] = cls
78 super().__init_subclass__(**kwargs)
80 @staticmethod
81 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]:
82 for base in cls.__bases__:
83 if issubclass(base, SubclassableOnce):
84 return base
85 else:
86 raise RuntimeError("wut")
88 @classmethod
89 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]":
90 try:
91 return cls.get_unique_subclass()
92 except AttributeError:
93 return cls
95 @classmethod
96 def get_unique_subclass(cls) -> "type[SubclassableOnce]":
97 existing = SubclassableOnce.__subclasses.get(cls)
98 if existing is None:
99 raise AttributeError("Could not find any subclass", cls)
100 return existing
102 @classmethod
103 def reset_subclass(cls) -> None:
104 log.debug("Resetting subclass of %s", cls)
105 cls.__subclasses[cls] = None
108def is_valid_phone_number(phone: str | None) -> bool:
109 if phone is None:
110 return False
111 match = re.match(r"\+\d.*", phone)
112 if match is None:
113 return False
114 return match[0] == phone
def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace each character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: Text to sanitize.
    :param repl: Literal text substituted for every offending character.
        It is inserted verbatim; backslashes are not interpreted.
    :return: The sanitized text.
    """
    # A callable replacement is inserted as-is. A plain string would be
    # parsed by re for escapes/backreferences ("\1", "\g<0>"), so a repl
    # containing a backslash would raise re.error or be altered.
    return ILLEGAL_XML_CHARS_RE.sub(lambda _match: repl, s)
# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges stripped from XML text: C0 controls
# other than tab, LF and CR; DEL and the C1 controls except NEL (U+0085);
# UTF-16 surrogates; and the Unicode noncharacters.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # Lone surrogates are not XML characters (the XML 1.0 "Char" production
    # excludes U+D800-U+DFFF) and cannot be encoded to UTF-8 either.
    (0xD800, 0xDFFF),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

# The characters are inserted verbatim into a regex character class. This is
# safe because none of the class metacharacters (], \, ^, -) falls inside the
# ranges above.
ILLEGAL_RANGES = [rf"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)
153# from https://stackoverflow.com/a/35804945/5902284
154def addLoggingLevel(
155 levelName: str = "TRACE",
156 levelNum: int = logging.DEBUG - 5,
157 methodName: str | None = None,
158) -> None:
159 """
160 Comprehensively adds a new logging level to the `logging` module and the
161 currently configured logging class.
163 `levelName` becomes an attribute of the `logging` module with the value
164 `levelNum`. `methodName` becomes a convenience method for both `logging`
165 itself and the class returned by `logging.getLoggerClass()` (usually just
166 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
167 used.
169 To avoid accidental clobberings of existing attributes, this method will
170 raise an `AttributeError` if the level name is already an attribute of the
171 `logging` module or if the method name is already present
173 Example
174 -------
175 >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
176 >>> logging.getLogger(__name__).setLevel("TRACE")
177 >>> logging.getLogger(__name__).trace('that worked')
178 >>> logging.trace('so did this')
179 >>> logging.TRACE
180 5
182 """
183 if not methodName:
184 methodName = levelName.lower()
186 if hasattr(logging, levelName):
187 log.debug(f"{levelName} already defined in logging module")
188 return
189 if hasattr(logging, methodName):
190 log.debug(f"{methodName} already defined in logging module")
191 return
192 if hasattr(logging.getLoggerClass(), methodName):
193 log.debug(f"{methodName} already defined in logger class")
194 return
196 # This method was inspired by the answers to Stack Overflow post
197 # http://stackoverflow.com/q/2183233/2988730, especially
198 # http://stackoverflow.com/a/13638084/2988730
199 def logForLevel(self, message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
200 if self.isEnabledFor(levelNum):
201 self._log(levelNum, message, args, **kwargs)
203 def logToRoot(message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa
204 logging.log(levelNum, message, *args, **kwargs)
206 logging.addLevelName(levelNum, levelName)
207 setattr(logging, levelName, levelNum)
208 setattr(logging.getLoggerClass(), methodName, logForLevel)
209 setattr(logging, methodName, logToRoot)
class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass that declares a ``trace`` method.

    NOTE(review): presumably used for type annotations of loggers that gain a
    real ``trace`` method at runtime via ``addLoggingLevel()`` — confirm. The
    method defined here does nothing.
    """

    def trace(self, *args: object, **kwargs: object) -> None:
        # Accept the usual ``(msg, *args, **kwargs)`` of logger methods so
        # calls such as ``log.trace("x=%s", x)`` type-check and do not raise
        # TypeError. Calling with no argument remains valid.
        pass
# Logger used by the helpers of this module.
log = logging.getLogger(name=__name__)
def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Collapse the presences of several resources into a single one.

    :param resources: Presence of each resource, keyed by a string
        (presumably the resource name).
    :return: ``None`` if there is no resource, the only entry itself if there
        is exactly one, and otherwise a new dict where:

        - ``show`` is ``""`` (available) if any resource is available, else the
          non-empty ``show`` of the highest-priority resource;
        - ``status`` is the first non-empty status by decreasing priority, or
          ``""`` if there is none;
        - ``priority`` is 0.
    :raises RuntimeError: if no resource is available and none has a non-empty
        ``show``.
    """
    if not resources:
        return None

    entries = list(resources.values())
    if len(entries) == 1:
        return entries[0]

    ranked = sorted(entries, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in entries):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }
def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Remove every VARIATION SELECTOR-16 (U+FE0F) from ``emoji``.

    This is required for compatibility with dino, and maybe other future
    clients.
    """
    # Replacing on the str directly avoids a UTF-8 encode/decode round trip,
    # which would raise UnicodeEncodeError on strings with lone surrogates.
    return emoji.replace("\ufe0f", "")
def deprecated(name: str, new: Callable):  # type:ignore[no-untyped-def,type-arg] # noqa
    """
    Build a deprecated alias that forwards to ``new``.

    :param name: Name of the deprecated alias, used in the warning text.
    :param new: Replacement callable; its ``__name__`` appears in the warning.
    :return: A function that emits a ``DeprecationWarning`` on every call and
        then returns ``new(*args, **kwargs)``.
    """

    # functools.wraps is intentionally not applied: it would copy the
    # replacement's __name__ and __doc__ onto the deprecated alias.
    def wrapped(*args, **kwargs):  # type:ignore[no-untyped-def] # noqa
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the caller of the alias, i.e. the code
            # that needs updating, rather than to this wrapper. This also lets
            # the default filters and per-location deduplication act on the
            # real call site.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[T]) -> T:
    """
    Build a ``cls`` named tuple from the matching keys of ``data``.

    Keys of ``data`` that are not fields of ``cls`` are ignored. A field that
    is missing from ``data`` takes the default declared on ``cls`` if there is
    one, and ``None`` otherwise.
    """
    defaults = cls._field_defaults  # type:ignore
    return cls(*(data.get(f, defaults.get(f)) for f in cls._fields))  # type:ignore
def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Replace each mentioned span of ``text`` with ``mapping(mention)``.

    :param text: Original text.
    :param mentions: Mentions, each covering
        ``text[mention.start:mention.end]``. They may be given in any order and
        are applied by increasing ``start``. Overlapping mentions are not
        handled specially.
    :param mapping: Returns the replacement text for a mention. For slidge
        <= 0.3.3 compatibility, if calling it with the mention raises, it is
        called again with ``mention.contact``.
    :return: ``text`` with the mentions replaced, or ``text`` unchanged when
        there are no mentions.
    """
    if not mentions:
        return text

    cursor = 0
    pieces = []
    # The slicing below needs increasing offsets. Unsorted mentions would move
    # the cursor backwards and duplicate or drop parts of the text.
    for mention in sorted(mentions, key=lambda m: m.start):
        try:
            new_text = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            new_text = mapping(mention.contact)  # type:ignore
        pieces.extend([text[cursor : mention.start], new_text])
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)
TimeItWrapped = TypeVar("TimeItWrapped")


def timeit(
    func: Callable[..., Awaitable[TimeItWrapped]],
) -> Callable[..., Awaitable[TimeItWrapped]]:
    """
    Decorate an async method to debug-log how long each call takes.

    The duration, in milliseconds, is logged through ``self.log`` of the
    instance the method is called on. Nothing is logged if the call raises.
    """
    # perf_counter() is monotonic and high-resolution. time() follows the
    # wall clock, which can jump (NTP sync, manual changes) and skew durations.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self: object, *args: object, **kwargs: object) -> TimeItWrapped:
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        self.log.debug(  # type:ignore
            "%s took %s ms", func.__name__, round((perf_counter() - start) * 1000)
        )
        return r

    return wrapped
def strip_leading_emoji(text: str) -> str:
    """
    Remove a leading emoji-only word and the first space that follows it.

    ``text`` is returned unchanged if the ``emoji`` library is not installed,
    if ``text`` contains no space, or if the part before the first space is
    empty or not made only of emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    head, sep, rest = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better.
    # An empty head (text starting with a space) is rejected explicitly, since
    # purely_emoji() may be vacuously true for an empty string.
    if sep and head and emoji.purely_emoji(head):
        return rest
    return text
async def noop_coro() -> None:
    """Coroutine that does nothing and completes immediately with ``None``."""
    return None
def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line of ``text`` is prefixed with ``"> "`` and has trailing whitespace
    removed, so an empty line becomes a bare ``">"``.
    """
    quoted_lines = [f"> {line}".strip() for line in text.split("\n")]
    return "\n".join(quoted_lines).strip()