Coverage for slidge / util / util.py: 78%
175 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-03-13 22:59 +0000
1import logging
2import mimetypes
3import re
4import warnings
5from collections.abc import Callable
6from functools import wraps
7from pathlib import Path
8from time import time
9from typing import Any, NamedTuple, TypeVar
11try:
12 import emoji
13except ImportError:
14 EMOJI_LIB_AVAILABLE = False
15else:
16 EMOJI_LIB_AVAILABLE = True
18from slixmpp.types import ResourceDict
20from .types import Mention
22try:
23 import magic
24except ImportError as e:
25 magic = None # type:ignore
26 logging.warning(
27 (
28 "Libmagic is not available: %s. "
29 "It's OK if you don't use fix-filename-suffix-mime-type."
30 ),
31 e,
32 )
def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the file's actual content type.

    The MIME type is sniffed from the file at *path* with libmagic and always
    wins over *mime_type*, which is only compared against it for logging.

    :param path: file on disk whose content is inspected
    :param mime_type: MIME type announced for the file, if any
    :param file_name: name to fix; ``path.name`` is used when it is falsy
    :return: ``(file name, MIME type guessed by libmagic)``. The name is kept
        unchanged when its suffix is already valid for the guessed type
        (compared case-insensitively) or when ``mimetypes`` knows no suffix
        for that type; otherwise its suffix is replaced by the one returned
        by ``mimetypes.guess_extension``.

    Requires the optional ``magic`` module: if it could not be imported,
    ``magic`` is ``None`` and this raises ``AttributeError``.
    """
    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)

    # mimetypes only knows bare "type/subtype" keys, so parameters such as
    # "; charset=binary" must be stripped for BOTH lookups below.
    bare_mime = guessed.split(";")[0].strip()
    valid_suffix_list = mimetypes.guess_all_extensions(bare_mime, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    # Case-insensitive so that e.g. "photo.JPG" is not needlessly renamed.
    if suffix.lower() in {valid.lower() for valid in valid_suffix_list}:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return str(name), guessed

    valid_suffix = mimetypes.guess_extension(bare_mime, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return str(name), guessed

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return str(name.with_suffix(valid_suffix)), guessed
67class SubclassableOnce:
68 # To allow importing everything, including plugins, during tests
69 TEST_MODE: bool = False
70 __subclasses: dict[type["SubclassableOnce"], type["SubclassableOnce"] | None] = {}
72 def __init_subclass__(cls, **kwargs: Any) -> None:
73 if SubclassableOnce not in cls.__bases__:
74 base = SubclassableOnce.__find_direct_child(cls)
75 existing = SubclassableOnce.__subclasses.get(base)
76 if existing is not None and not SubclassableOnce.TEST_MODE:
77 raise RuntimeError("This class must be subclassed once at most!")
78 cls.__subclasses[base] = cls
79 super().__init_subclass__(**kwargs)
81 @staticmethod
82 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]:
83 for base in cls.__bases__:
84 if issubclass(base, SubclassableOnce):
85 return base
86 else:
87 raise RuntimeError("wut")
89 @classmethod
90 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]":
91 try:
92 return cls.get_unique_subclass()
93 except AttributeError:
94 return cls
96 @classmethod
97 def get_unique_subclass(cls) -> "type[SubclassableOnce]":
98 existing = SubclassableOnce.__subclasses.get(cls)
99 if existing is None:
100 raise AttributeError("Could not find any subclass", cls)
101 return existing
103 @classmethod
104 def reset_subclass(cls) -> None:
105 log.debug("Resetting subclass of %s", cls)
106 cls.__subclasses[cls] = None
109def is_valid_phone_number(phone: str | None):
110 if phone is None:
111 return False
112 match = re.match(r"\+\d.*", phone)
113 if match is None:
114 return False
115 return match[0] == phone
def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` with *repl*.

    *repl* is inserted literally: backslashes and group references in it are
    not interpreted as a regex replacement template.

    :param s: text to clean
    :param repl: replacement for each illegal character (default: remove it)
    :return: the cleaned text
    """
    # A callable replacement bypasses re's template processing, so e.g.
    # repl="\\" is inserted as-is instead of raising re.error.
    return ILLEGAL_XML_CHARS_RE.sub(lambda _match: repl, s)
# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges removed by strip_illegal_chars().
# They cover characters XML 1.0 forbids (most C0 controls, surrogates,
# U+FFFE/U+FFFF) plus C1 controls (except U+0085) and Unicode noncharacters,
# which XML 1.0 discourages.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # Lone surrogates are not XML characters and cannot be encoded to UTF-8.
    (0xD800, 0xDFFF),
    # The noncharacter block is U+FDD0..U+FDEF (32 code points).
    (0xFDD0, 0xFDEF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

# Each range becomes "<low>-<high>" inside one character class; none of the
# boundary characters is special inside a class ("-", "]", "\\", "^").
ILLEGAL_RANGES = [rf"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)
154# from https://stackoverflow.com/a/35804945/5902284
155def addLoggingLevel(
156 levelName: str = "TRACE", levelNum: int = logging.DEBUG - 5, methodName=None
157) -> None:
158 """
159 Comprehensively adds a new logging level to the `logging` module and the
160 currently configured logging class.
162 `levelName` becomes an attribute of the `logging` module with the value
163 `levelNum`. `methodName` becomes a convenience method for both `logging`
164 itself and the class returned by `logging.getLoggerClass()` (usually just
165 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
166 used.
168 To avoid accidental clobberings of existing attributes, this method will
169 raise an `AttributeError` if the level name is already an attribute of the
170 `logging` module or if the method name is already present
172 Example
173 -------
174 >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
175 >>> logging.getLogger(__name__).setLevel("TRACE")
176 >>> logging.getLogger(__name__).trace('that worked')
177 >>> logging.trace('so did this')
178 >>> logging.TRACE
179 5
181 """
182 if not methodName:
183 methodName = levelName.lower()
185 if hasattr(logging, levelName):
186 log.debug(f"{levelName} already defined in logging module")
187 return
188 if hasattr(logging, methodName):
189 log.debug(f"{methodName} already defined in logging module")
190 return
191 if hasattr(logging.getLoggerClass(), methodName):
192 log.debug(f"{methodName} already defined in logger class")
193 return
195 # This method was inspired by the answers to Stack Overflow post
196 # http://stackoverflow.com/q/2183233/2988730, especially
197 # http://stackoverflow.com/a/13638084/2988730
198 def logForLevel(self, message, *args, **kwargs) -> None:
199 if self.isEnabledFor(levelNum):
200 self._log(levelNum, message, args, **kwargs)
202 def logToRoot(message, *args, **kwargs) -> None:
203 logging.log(levelNum, message, *args, **kwargs)
205 logging.addLevelName(levelNum, levelName)
206 setattr(logging, levelName, levelNum)
207 setattr(logging.getLoggerClass(), methodName, logForLevel)
208 setattr(logging, methodName, logToRoot)
class SlidgeLogger(logging.Logger):
    """
    Logger class that declares a ``trace`` method.

    ``addLoggingLevel()`` attaches ``trace`` to the logger class at runtime.
    This stub presumably exists so static type checkers accept ``log.trace(...)``
    calls — TODO confirm how it is used.
    """

    def trace(self, msg: object = "", *args: Any, **kwargs: Any) -> None:
        # Signature mirrors the dynamically installed
        # ``trace(self, message, *args, **kwargs)``; the body intentionally
        # does nothing. All arguments are optional so the old zero-argument
        # form still works.
        pass
# Module-level logger used by the helpers of this module.
log = logging.getLogger(name=__name__)
def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Combine the presences of several resources into a single presence.

    - no resources: ``None``;
    - exactly one resource: that resource's dict itself (not a copy);
    - otherwise a new dict with ``priority`` 0, where ``show`` is ``""``
      ("available") if any resource is available, else the first truthy
      ``show`` by descending priority, and ``status`` is the first non-empty
      status by descending priority (``""`` if none).

    :raises RuntimeError: when several resources are given and no ``show``
        is either ``""`` or truthy.
    """
    if not resources:
        return None

    candidates = list(resources.values())
    if len(candidates) == 1:
        return candidates[0]

    ranked = sorted(candidates, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in candidates):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }
def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Remove every U+FE0F (VARIATION SELECTOR-16) character from *emoji*.

    This is required for compatibility with dino, and maybe other future
    clients.

    :param emoji: text to clean (the name shadows the optional ``emoji``
        module but is kept for keyword-argument compatibility)
    :return: the text without any U+FE0F
    """
    # b"\xef\xb8\x8f" is U+FE0F in UTF-8; replacing on the str is equivalent
    # but avoids an encode/decode round-trip that fails on strings that are
    # not UTF-8 encodable (e.g. lone surrogates).
    return emoji.replace("\ufe0f", "")
def deprecated(name: str, new: Callable):
    """
    Build a deprecated alias that warns and then forwards to *new*.

    :param name: name of the deprecated callable, quoted in the warning
    :param new: replacement callable; its ``__name__`` is quoted in the
        warning and it receives all positional and keyword arguments
    :return: a function that emits a ``DeprecationWarning`` on every call
        and returns whatever *new* returns
    """

    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated
            # name rather than to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: type[T]) -> T:
    """
    Instantiate *cls* from *data*, passing one positional argument per field.

    Fields are read in ``cls._fields`` order. A key absent from *data* yields
    ``None`` (the field's declared default is not used), and keys of *data*
    that are not fields are ignored.
    """
    values = [data.get(field_name) for field_name in cls._fields]
    return cls(*values)  # type:ignore
def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Replace each mention's ``text[start:end]`` span with ``mapping(mention)``.

    If *mapping* raises, it is retried with ``mention.contact``, for
    compatibility with callables written for slidge <= 0.3.3.

    :param text: original text
    :param mentions: mentions with ``start``/``end`` offsets into *text*;
        they are processed in ascending ``start`` order regardless of the
        order given
    :param mapping: returns the replacement text for a mention
    :return: the text with every mention span replaced (unchanged if there
        are no mentions)
    """
    if not mentions:
        return text

    cursor = 0
    pieces = []
    # The slicing below only works left to right; sort (stably) so that
    # out-of-order mentions do not duplicate or garble text.
    for mention in sorted(mentions, key=lambda m: m.start):
        try:
            new_text = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            new_text = mapping(mention.contact)  # type:ignore
        pieces.extend([text[cursor : mention.start], new_text])
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)
def timeit(func):
    """
    Decorate an ``async`` method so that each successful call logs its duration.

    The duration, in whole milliseconds, is logged at DEBUG level through the
    ``log`` attribute of the instance the method is called on, which the
    owner class must therefore provide. Nothing is logged if the call raises.
    """
    from time import perf_counter

    @wraps(func)
    async def wrapped(self, *args, **kwargs):
        # perf_counter is monotonic, unlike time(): system clock adjustments
        # cannot make the measured duration wrong or negative.
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        self.log.debug(
            "%s took %s ms", func.__name__, round((perf_counter() - start) * 1000)
        )
        return r

    return wrapped
def strip_leading_emoji(text: str) -> str:
    """
    Drop a leading emoji "word" from *text*.

    Without the optional ``emoji`` library, *text* is returned unchanged.
    Otherwise, if *text* contains a space and everything before the first
    space consists purely of emoji, that prefix and the space are removed.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    head, separator, tail = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if separator and emoji.purely_emoji(head):
        return tail
    return text
async def noop_coro() -> None:
    """Coroutine that does nothing; awaiting it yields ``None``."""
    return None
def add_quote_prefix(text: str):
    """
    Quote *text* line by line, e.g. for a reply.

    Every line is prefixed with "> " and stripped of surrounding whitespace
    afterwards, so an empty line becomes a bare ">".
    """
    quoted_lines = []
    for line in text.split("\n"):
        quoted_lines.append(f"> {line}".strip())
    return "\n".join(quoted_lines).strip()