Coverage for slidge/util/util.py: 74%
174 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import logging
2import mimetypes
3import re
4import subprocess
5import warnings
6from abc import ABCMeta
7from functools import wraps
8from pathlib import Path
9from time import time
10from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Type, TypeVar
12try:
13 import emoji
14except ImportError:
15 EMOJI_LIB_AVAILABLE = False
16else:
17 EMOJI_LIB_AVAILABLE = True
19from .types import Mention, ResourceDict
21if TYPE_CHECKING:
22 from ..contact.contact import LegacyContact
24try:
25 import magic
26except ImportError as e:
27 magic = None # type:ignore
28 logging.warning(
29 (
30 "Libmagic is not available: %s. "
31 "It's OK if you don't use fix-filename-suffix-mime-type."
32 ),
33 e,
34 )
def fix_suffix(path: Path, mime_type: Optional[str], file_name: Optional[str]):
    """Return a file name whose suffix agrees with the file's actual content.

    The MIME type is sniffed from the file at *path* with libmagic; when it
    differs from the given *mime_type*, the sniffed one wins. The name being
    fixed is *file_name* if given, else the basename of *path*.

    The existing suffix is kept when it is one of the extensions that
    ``mimetypes`` knows for the MIME type. The comparison is case-insensitive,
    so e.g. ".JPG" is accepted for JPEG content. Otherwise the suffix is
    replaced by the preferred extension for the type. If no extension is known
    for the type, the name is returned unchanged.

    Requires libmagic: if the ``magic`` module could not be imported at module
    load, ``magic`` is None and this raises AttributeError.

    :param path: file on disk whose content is inspected
    :param mime_type: MIME type announced for the file, possibly wrong or None
    :param file_name: original file name, if known
    :return: a Path made of the (possibly re-suffixed) file name
    """
    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)
        mime_type = guessed

    # mimetypes only knows bare types, so drop parameters such as
    # "; charset=utf-8" before BOTH lookups, not just the fallback one.
    bare_mime = mime_type.split(";")[0].strip()
    valid_suffix_list = mimetypes.guess_all_extensions(bare_mime, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    # Case-insensitive: an uppercase ".JPG" is as valid as ".jpg".
    if suffix.lower() in {s.lower() for s in valid_suffix_list}:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return name

    valid_suffix = mimetypes.guess_extension(bare_mime, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return name

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return name.with_suffix(valid_suffix)
class SubclassableOnce(type):
    """Metaclass for base classes that may be subclassed at most once.

    When a class is created with a base whose metaclass is exactly
    SubclassableOnce or ABCSubclassableOnceAtMost, the new class is recorded on
    that base as ``_subclass``. Creating a second such subclass raises
    RuntimeError, unless TEST_MODE is enabled, in which case the registration
    is simply overwritten.
    """

    # To allow importing everything, including plugins, during tests
    TEST_MODE = False

    def __init__(cls, name, bases, dct):
        for base in bases:
            if type(base) not in (SubclassableOnce, ABCSubclassableOnceAtMost):
                continue
            already_registered = hasattr(base, "_subclass")
            if already_registered and not cls.TEST_MODE:
                raise RuntimeError(
                    "This class must be subclassed once at most!",
                    cls,
                    name,
                    bases,
                    dct,
                )
            log.debug("Setting %s as subclass for %s", cls, base)
            base._subclass = cls

        super().__init__(name, bases, dct)

    def get_self_or_unique_subclass(cls):
        """Return the registered subclass, or *cls* itself if there is none."""
        try:
            found = cls.get_unique_subclass()
        except AttributeError:
            found = cls
        return found

    def get_unique_subclass(cls):
        """Return the registered subclass; raise AttributeError if none."""
        found = getattr(cls, "_subclass", None)
        if found is not None:
            return found
        raise AttributeError("Could not find any subclass", cls)

    def reset_subclass(cls):
        """Forget the registered subclass, if any (only logs otherwise)."""
        try:
            log.debug("Resetting subclass of %s", cls)
            delattr(cls, "_subclass")
        except AttributeError:
            log.debug("No subclass were registered for %s", cls)
class ABCSubclassableOnceAtMost(ABCMeta, SubclassableOnce):
    """Metaclass combining ABCMeta with SubclassableOnce.

    Use it for abstract base classes that must also follow the
    "subclassed once at most" rule enforced by SubclassableOnce.
    """
def is_valid_phone_number(phone: Optional[str]):
    """Tell whether *phone* looks like an international phone number.

    The whole string must be a "+" followed by a digit, then any characters
    other than a newline. None and every other string yield False.
    """
    if phone is None:
        return False
    return re.fullmatch(r"\+\d.*", phone) is not None
def strip_illegal_chars(s: str):
    """Return *s* with every character matched by ILLEGAL_XML_CHARS_RE removed."""
    return re.sub(ILLEGAL_XML_CHARS_RE, "", s)
# Code point ranges stripped from text before it goes into XML.
# Adapted from https://stackoverflow.com/a/64570125/5902284 and Link Mauve.
# This covers characters outside the XML 1.0 ``Char`` production (C0 controls
# other than TAB/LF/CR, surrogates, U+FFFE/U+FFFF) plus the "compatibility
# characters" the XML spec discourages (C1 controls except NEL U+0085 and the
# Unicode noncharacters).
# Lone surrogates can occur in Python str (e.g. via "surrogateescape") and
# cannot be encoded to UTF-8, so they must be stripped too.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    (0xD800, 0xDFFF),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

# None of the range endpoints is a regex metacharacter inside a character
# class ("]", "\\", "^", "-"), so they can be inserted literally.
ILLEGAL_RANGES = [rf"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)
# from https://stackoverflow.com/a/35804945/5902284
def addLoggingLevel(
    levelName: str = "TRACE", levelNum: int = logging.DEBUG - 5, methodName=None
):
    """
    Comprehensively adds a new logging level to the `logging` module and the
    currently configured logging class.

    `levelName` becomes an attribute of the `logging` module with the value
    `levelNum`. `methodName` becomes a convenience method for both `logging`
    itself and the class returned by `logging.getLoggerClass()` (usually just
    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
    used.

    To avoid accidental clobbering of existing attributes, this function does
    nothing except emit a debug log message in three cases: the level name is
    already an attribute of the `logging` module, the method name is already
    an attribute of the `logging` module, or the method name is already an
    attribute of the logger class. Calling it twice with the same arguments is
    therefore harmless.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel("TRACE")
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5

    """
    if not methodName:
        methodName = levelName.lower()

    # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
    if hasattr(logging, levelName):
        log.debug("%s already defined in logging module", levelName)
        return
    if hasattr(logging, methodName):
        log.debug("%s already defined in logging module", methodName)
        return
    if hasattr(logging.getLoggerClass(), methodName):
        log.debug("%s already defined in logger class", methodName)
        return

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs):
        # Same enabled-check as the built-in Logger.debug/info/... methods.
        if self.isEnabledFor(levelNum):
            self._log(levelNum, message, args, **kwargs)

    def logToRoot(message, *args, **kwargs):
        logging.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logging.getLoggerClass(), methodName, logForLevel)
    setattr(logging, methodName, logToRoot)
class SlidgeLogger(logging.Logger):
    """Logger subclass that declares the custom ``trace`` method.

    ``trace`` takes the same arguments as ``Logger.debug``, so calls such as
    ``log.trace("x=%s", x)`` are accepted. It logs at ``logging.TRACE`` when
    that level has been registered (see addLoggingLevel), otherwise at
    ``logging.DEBUG - 5``, the default level used by addLoggingLevel.
    NOTE(review): probably used mainly as a type annotation for loggers
    that get ``trace`` attached at runtime; confirm against callers.
    """

    def trace(self, msg: object, *args, **kwargs) -> None:
        level = getattr(logging, "TRACE", logging.DEBUG - 5)
        if self.isEnabledFor(level):
            self._log(level, msg, args, **kwargs)
# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)
def get_version():
    """Return a version string derived from the git checkout, if any.

    Runs ``git rev-parse HEAD`` in the directory containing this module,
    rather than in the process's current working directory, which may be an
    unrelated repository. Returns ``"git-"`` followed by the first 10
    characters of the commit hash.

    Returns ``"NO_VERSION"`` in two cases: git cannot be executed (not
    installed, not permitted, ...), or the command fails, e.g. because the
    module does not live inside a git work tree.
    """
    try:
        git = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            stderr=subprocess.DEVNULL,
            cwd=Path(__file__).resolve().parent,
        ).decode()
    except (OSError, subprocess.CalledProcessError):
        # OSError covers FileNotFoundError (no git binary) as well as
        # PermissionError and friends.
        pass
    else:
        return "git-" + git[:10]

    return "NO_VERSION"
def merge_resources(resources: dict[str, ResourceDict]) -> Optional[ResourceDict]:
    """Combine the presences of several resources into a single one.

    Returns None for an empty mapping, and the very same resource dict when
    there is only one. Otherwise a new dict is built:

    - ``show`` is "" (available) as soon as any resource is available;
      otherwise it is the first non-empty ``show`` in descending priority order.
    - ``status`` is the first non-empty status in descending priority order,
      or "" when no resource has one. Resources without a status are skipped,
      even with a high priority.
    - ``priority`` is always 0.

    Raises RuntimeError when no resource is available and none has a
    non-empty ``show``.
    """
    if not resources:
        return None

    if len(resources) == 1:
        (single,) = resources.values()
        return single

    # Highest priority first; sorted() is stable, so ties keep insertion order.
    ranked = sorted(resources.values(), key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in resources.values()):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }
def remove_emoji_variation_selector_16(emoji: str):
    """Return *emoji* with every U+FE0F (VARIATION SELECTOR-16) removed.

    This is required for compatibility with dino, and maybe other future
    clients.

    Operates directly on the str. This gives the same result as stripping the
    UTF-8 byte sequence EF B8 8F, but does not raise UnicodeEncodeError on
    strings that contain lone surrogates.
    """
    return emoji.replace("\ufe0f", "")
def deprecated(name: str, new: Callable):
    """Return a stand-in for the old callable *name* that forwards to *new*.

    Every call emits a DeprecationWarning
    "<name> is deprecated. Use <new.__name__> instead", then calls *new* with
    the same arguments and returns its result. The wrapper carries *new*'s
    metadata (``functools.wraps``), including ``__wrapped__``.
    """

    @wraps(new)
    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated
            # name, not to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: Type[T]) -> T:
    """Build an instance of the NamedTuple class *cls* from the mapping *data*.

    Keys of *data* that are not fields of *cls* are ignored. A field missing
    from *data* takes the default declared on *cls* if there is one, and None
    otherwise. An explicit None in *data* is kept as None.
    """
    defaults = getattr(cls, "_field_defaults", {})
    return cls(*(data.get(f, defaults.get(f)) for f in cls._fields))  # type:ignore
def replace_mentions(
    text: str,
    mentions: Optional[list[Mention]],
    mapping: Callable[["LegacyContact"], str],
):
    """Replace each mentioned span of *text* with ``mapping(mention.contact)``.

    A mention covers ``text[mention.start:mention.end]``. Mentions are applied
    in order of their start offset, so callers need not pass them sorted.
    They are assumed not to overlap. Returns *text* unchanged when *mentions*
    is None or empty.
    """
    if not mentions:
        return text

    cursor = 0
    pieces = []
    # Sorting guards against out-of-order mentions, which would otherwise
    # make the cursor jump backwards and duplicate/garble text.
    for mention in sorted(mentions, key=lambda m: m.start):
        pieces.append(text[cursor : mention.start])
        pieces.append(mapping(mention.contact))
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)
def with_session(func):
    """Decorate an async method so it runs inside ``self.xmpp.store.session()``.

    The session context is entered before awaiting *func* and exited once it
    returns or raises. The coroutine's result is passed through unchanged.
    """

    @wraps(func)
    async def run_in_session(self, *args, **kwargs):
        store = self.xmpp.store
        with store.session():
            result = await func(self, *args, **kwargs)
        return result

    return run_in_session
def timeit(func):
    """Decorate an async method so that its duration is logged.

    After the wrapped coroutine returns, logs "<func name> took <N> ms" at
    INFO level on ``self.log``. N is the elapsed time rounded to the nearest
    millisecond. Nothing is logged if the coroutine raises.
    """
    # perf_counter is the clock meant for measuring short durations;
    # time.time() follows the wall clock, which may be adjusted (e.g. by NTP)
    # while the coroutine runs and yield bogus or negative durations.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self, *args, **kwargs):
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - start) * 1000)
        self.log.info("%s took %s ms", func.__name__, elapsed_ms)
        return r

    return wrapped
def strip_leading_emoji(text: str) -> str:
    """Drop the first space-separated word of *text* if it is made only of emoji.

    Returns *text* unchanged in three cases: the optional ``emoji`` library is
    not installed, *text* has no space, or its first word is not purely emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    first_word, separator, remainder = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if separator and emoji.purely_emoji(first_word):
        return remainder
    return text