Coverage for slidge / util / util.py: 80%

173 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 04:38 +0000

1import logging 

2import mimetypes 

3import re 

4from collections.abc import Callable, Collection, Coroutine 

5from functools import wraps 

6from pathlib import Path 

7from time import time 

8from typing import ( 

9 Any, 

10 ClassVar, 

11 Concatenate, 

12 NamedTuple, 

13 ParamSpec, 

14 Protocol, 

15 TypeVar, 

16) 

17 

18try: 

19 import emoji 

20except ImportError: 

21 EMOJI_LIB_AVAILABLE = False 

22else: 

23 EMOJI_LIB_AVAILABLE = True 

24 

25from slixmpp.types import ExtPresenceShows, ResourceDict 

26 

27from .types import Mention 

28 

29try: 

30 import magic 

31except ImportError as e: 

32 magic = None # type:ignore 

33 logging.warning( 

34 ( 

35 "Libmagic is not available: %s. " 

36 "It's OK if you don't use fix-filename-suffix-mime-type." 

37 ), 

38 e, 

39 ) 

40 

41 

def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the MIME type detected by libmagic.

    Requires the module-level ``magic`` import to have succeeded.

    :param path: File whose content is inspected by libmagic.
    :param mime_type: MIME type announced for the file. It is only compared
        with the detected type for logging; the detected type always wins.
    :param file_name: Name to check; ``path.name`` is used when it is empty.
    :return: The file name (with its suffix replaced if it did not fit the
        detected type and a fitting one is known) and the detected MIME type.
    """
    detected = magic.from_file(path, mime=True)
    if detected != mime_type:
        log.debug("Magic (%s) and given MIME (%s) differ", detected, mime_type)
    else:
        log.debug("Magic and given MIME match")

    # From here on, only the detected type is used.
    known_suffixes = mimetypes.guess_all_extensions(detected, strict=False)
    candidate = Path(file_name or path.name)
    current_suffix = candidate.suffix

    if current_suffix in known_suffixes:
        log.debug("Suffix %s is in %s", current_suffix, known_suffixes)
        return str(candidate), detected

    # Drop any ";parameter" part before asking for the preferred extension.
    preferred = mimetypes.guess_extension(detected.split(";")[0], strict=False)
    if preferred is not None:
        log.debug("Changing suffix of %s to %s", file_name or path.name, preferred)
        return str(candidate.with_suffix(preferred)), detected

    log.debug("No valid suffix found")
    return str(candidate), detected

69 

70 

71class SubclassableOnce: 

72 # To allow importing everything, including plugins, during tests 

73 TEST_MODE: bool = False 

74 __subclasses: ClassVar[ 

75 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None] 

76 ] = {} 

77 

78 def __init_subclass__(cls, **kwargs: object) -> None: 

79 if SubclassableOnce not in cls.__bases__: 

80 base = SubclassableOnce.__find_direct_child(cls) 

81 existing = SubclassableOnce.__subclasses.get(base) 

82 if existing is not None and not SubclassableOnce.TEST_MODE: 

83 raise RuntimeError("This class must be subclassed once at most!") 

84 cls.__subclasses[base] = cls 

85 super().__init_subclass__(**kwargs) 

86 

87 @staticmethod 

88 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]: 

89 for base in cls.__bases__: 

90 if issubclass(base, SubclassableOnce): 

91 return base 

92 else: 

93 raise RuntimeError("wut") 

94 

95 @classmethod 

96 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]": 

97 try: 

98 return cls.get_unique_subclass() 

99 except AttributeError: 

100 return cls 

101 

102 @classmethod 

103 def get_unique_subclass(cls) -> "type[SubclassableOnce]": 

104 existing = SubclassableOnce.__subclasses.get(cls) 

105 if existing is None: 

106 raise AttributeError("Could not find any subclass", cls) 

107 return existing 

108 

109 @classmethod 

110 def reset_subclass(cls) -> None: 

111 log.debug("Resetting subclass of %s", cls) 

112 cls.__subclasses[cls] = None 

113 

114 

115def is_valid_phone_number(phone: str | None) -> bool: 

116 if phone is None: 

117 return False 

118 match = re.match(r"\+\d.*", phone) 

119 if match is None: 

120 return False 

121 return match[0] == phone 

122 

123 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character of ``s`` matched by ``ILLEGAL_XML_CHARS_RE``.

    :param s: Text to clean.
    :param repl: Replacement for each matched character (removed by default).
    :return: The cleaned text.
    """
    cleaned = ILLEGAL_XML_CHARS_RE.sub(repl, s)
    return cleaned

126 

127 

# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (first, last) code point ranges matched by ILLEGAL_XML_CHARS_RE.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    (0xFDD0, 0xFDDF),
    # The last two code points of each of the 17 planes:
    # U+FFFE-U+FFFF, U+1FFFE-U+1FFFF, ..., U+10FFFE-U+10FFFF.
    *(((plane << 16) | 0xFFFE, (plane << 16) | 0xFFFF) for plane in range(0x11)),
]

# One "first-last" character-class range per entry of ILLEGAL.
ILLEGAL_RANGES = [chr(first) + "-" + chr(last) for first, last in ILLEGAL]
# A single character class made of all the ranges above.
XML_ILLEGAL_CHARACTER_REGEX = f"[{''.join(ILLEGAL_RANGES)}]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

158 

159 

# from https://stackoverflow.com/a/35804945/5902284
def addLoggingLevel(
    levelName: str = "TRACE",
    levelNum: int = logging.DEBUG - 5,
    methodName: str | None = None,
) -> None:
    """
    Comprehensively adds a new logging level to the `logging` module and the
    currently configured logging class.

    `levelName` becomes an attribute of the `logging` module with the value
    `levelNum`. `methodName` becomes a convenience method for both `logging`
    itself and the class returned by `logging.getLoggerClass()` (usually just
    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
    used.

    To avoid accidental clobberings of existing attributes, nothing is
    changed if the level name is already an attribute of the `logging` module
    or if the method name is already present on the module or on the logger
    class: a debug message is logged and the function returns.

    Records emitted through the new convenience methods are attributed to the
    code that called them, not to the wrappers defined here.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel("TRACE")
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5

    """
    if not methodName:
        methodName = levelName.lower()

    if hasattr(logging, levelName):
        log.debug(f"{levelName} already defined in logging module")
        return
    if hasattr(logging, methodName):
        log.debug(f"{methodName} already defined in logging module")
        return
    if hasattr(logging.getLoggerClass(), methodName):
        log.debug(f"{methodName} already defined in logger class")
        return

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs) -> None:  # type:ignore[no-untyped-def] # noqa
        # Go through the public Logger.log() (it performs the isEnabledFor
        # check) and skip one extra frame, so that the record's
        # file/line/function describe our caller instead of this wrapper.
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
        self.log(levelNum, message, *args, **kwargs)

    def logToRoot(message, *args, **kwargs) -> None:  # type:ignore[no-untyped-def] # noqa
        root = logging.getLogger()
        # Same implicit configuration as the module-level logging.log().
        if not root.handlers:
            logging.basicConfig()
        # Skip this wrapper when looking for the caller (see logForLevel).
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1
        root.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logging.getLoggerClass(), methodName, logForLevel)
    setattr(logging, methodName, logToRoot)

217 

218 

class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass declaring a ``trace`` method.

    NOTE(review): presumably a typing aid for the method that
    ``addLoggingLevel`` installs at runtime - confirm.
    """

    def trace(self) -> None:
        """Do nothing."""
        return None

222 

223 

# Module-level logger used by the helpers of this file.
log = logging.getLogger(__name__)

225 

226 

def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Collapse the presences of several resources into a single one.

    - No resource: ``None``.
    - One resource: that resource's dict, returned as-is.
    - Otherwise: a new dict with priority 0. Its ``show`` is ``""`` if any
      resource has an empty ``show``, else the first non-empty ``show`` in
      decreasing priority order. Its ``status`` is the first non-empty
      ``status`` in decreasing priority order, or ``""`` if there is none.

    :raises RuntimeError: if several resources are given, none has an
        empty-string ``show`` and none has a truthy one.
    """
    if not resources:
        return None

    presences = list(resources.values())
    if len(presences) == 1:
        return presences[0]

    ranked = sorted(presences, key=lambda p: p["priority"], reverse=True)

    show: ExtPresenceShows
    if any(p["show"] == "" for p in presences):
        # if a client is "available", we're "available"
        show = ""
    else:
        first_show = next((p["show"] for p in ranked if p["show"]), None)
        if first_show is None:
            raise RuntimeError()
        show = first_show

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((p["status"] for p in ranked if p["status"]), "")

    return {"show": show, "status": status, "priority": 0}

260 

261 

def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Return ``emoji`` without any VARIATION SELECTOR-16 (U+FE0F) code point.

    The replacement is done directly on the string, so no encode/decode round
    trip is needed and text that cannot be encoded as UTF-8 (lone surrogates)
    does not raise.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")

265 

266 

# Type variable standing for whichever NamedTuple class is requested.
NamedTupleT = TypeVar("NamedTupleT", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[NamedTupleT]) -> NamedTupleT:
    """
    Build an instance of the named tuple ``cls`` from ``data``.

    Each field of ``cls`` takes the value stored under the same key in
    ``data``, or ``None`` when the key is absent; other keys are ignored.
    """
    values = map(data.get, cls._fields)
    return cls(*values)  # type:ignore[arg-type]

272 

273 

def replace_mentions(
    text: str,
    mentions: Collection[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Substitute each mention's span in ``text`` with the output of ``mapping``.

    Mentions are processed in iteration order; for each one, the text between
    the end of the previous mention and ``mention.start`` is kept, followed by
    the replacement, and reading resumes at ``mention.end``.

    :param text: Original text.
    :param mentions: Mentions to replace; ``text`` is returned unchanged when
        this is empty or ``None``.
    :param mapping: Called with each mention to get its replacement. If it
        raises, it is called again with ``mention.contact`` instead.
    :return: The text with all mentions replaced.
    """
    if not mentions:
        return text

    def render(mention):  # type:ignore[no-untyped-def]
        try:
            return mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            return mapping(mention.contact)  # type:ignore

    chunks: list[str] = []
    position = 0
    for mention in mentions:
        replacement = render(mention)
        chunks.append(text[position : mention.start])
        chunks.append(replacement)
        position = mention.end
    chunks.append(text[position:])
    return "".join(chunks)

294 

295 

class HasLogger(Protocol):
    """Structural type for any object exposing a ``log`` logger attribute."""

    # Logger that ``timeit`` uses to report how long a call took.
    log: logging.Logger

298 

299 

# Parameters of the decorated coroutine function, after ``self``.
P = ParamSpec("P")
# Result produced by the decorated coroutine.
T = TypeVar("T")
# Type of ``self``: anything exposing a ``log`` attribute.
Self = TypeVar("Self", bound=HasLogger)
# A coroutine function (method) taking ``self`` followed by ``P``.
TimeItWrapped = Callable[Concatenate[Self, P], Coroutine[Any, Any, T]]

304 

305 

def timeit(func: TimeItWrapped[Self, P, T]) -> TimeItWrapped[Self, P, T]:
    """
    Decorate a coroutine method so that its duration is logged.

    After the wrapped coroutine returns, the elapsed time, rounded to the
    millisecond, is logged at debug level on ``self.log``. Nothing is logged
    if the coroutine raises.

    :param func: Coroutine method of an object exposing a ``log`` attribute.
    :return: A coroutine function with the same signature and result.
    """
    # Imported here so the block does not depend on a new file-level import.
    # perf_counter is monotonic and high-resolution, unlike the wall clock,
    # which may jump while the coroutine is running.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self: Self, /, *args: P.args, **kwargs: P.kwargs) -> T:
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - start) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)
        return r

    return wrapped

315 

316 

def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it is purely emoji.

    ``text`` is returned unchanged when the ``emoji`` library is unavailable,
    when it contains no space, or when its first word is not purely emoji.
    """
    if EMOJI_LIB_AVAILABLE:
        first_word, space, remainder = text.partition(" ")
        # is_emoji returns False for 🛷️ for obscure reasons,
        # purely_emoji seems better
        if space and emoji.purely_emoji(first_word):
            return remainder
    return text

326 

327 

async def noop_coro() -> None:
    """Coroutine that completes immediately without doing anything."""
    return None

330 

331 

def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line is prefixed with "> " and loses its trailing whitespace, so an
    empty line becomes a bare ">".
    """
    # Every quoted line starts with ">" and ends without whitespace, so only
    # the right-hand side of each line ever needs stripping.
    quoted_lines = [f"> {line}".rstrip() for line in text.split("\n")]
    return "\n".join(quoted_lines)