Coverage for slidge / util / util.py: 79%

174 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-20 19:56 +0000

1import logging 

2import mimetypes 

3import re 

4import warnings 

5from collections.abc import Awaitable, Callable 

6from functools import wraps 

7from pathlib import Path 

8from time import time 

9from typing import Any, ClassVar, NamedTuple, TypeVar 

10 

11try: 

12 import emoji 

13except ImportError: 

14 EMOJI_LIB_AVAILABLE = False 

15else: 

16 EMOJI_LIB_AVAILABLE = True 

17 

18from slixmpp.types import ResourceDict 

19 

20from .types import Mention 

21 

22try: 

23 import magic 

24except ImportError as e: 

25 magic = None # type:ignore 

26 logging.warning( 

27 ( 

28 "Libmagic is not available: %s. " 

29 "It's OK if you don't use fix-filename-suffix-mime-type." 

30 ), 

31 e, 

32 ) 

33 

34 

def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Return a file name whose suffix agrees with the content of ``path``.

    The MIME type is detected from the file content with libmagic. When it
    differs from ``mime_type``, the detected type is the one that is used.

    :param path: Location of the file on disk.
    :param mime_type: MIME type announced for the file, if any.
    :param file_name: Preferred file name; ``path.name`` is used when it is
        empty or ``None``.
    :return: A ``(file name, detected MIME type)`` tuple. The file name is
        left untouched when its suffix is already valid for the detected
        type, or when no suffix is known for that type.
    :raises AttributeError: If the ``magic`` import failed at module load
        (``magic`` is then ``None``).
    """
    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)

    # ``mimetypes`` only knows bare types: drop any "; param=value" part once,
    # so that both lookups below are made with the same key.
    bare_type = guessed.split(";")[0].strip()

    valid_suffix_list = mimetypes.guess_all_extensions(bare_type, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    if suffix in valid_suffix_list:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return str(name), guessed

    valid_suffix = mimetypes.guess_extension(bare_type, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return str(name), guessed

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return str(name.with_suffix(valid_suffix)), guessed

62 

63 

64class SubclassableOnce: 

65 # To allow importing everything, including plugins, during tests 

66 TEST_MODE: bool = False 

67 __subclasses: ClassVar[ 

68 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None] 

69 ] = {} 

70 

71 def __init_subclass__(cls, **kwargs: object) -> None: 

72 if SubclassableOnce not in cls.__bases__: 

73 base = SubclassableOnce.__find_direct_child(cls) 

74 existing = SubclassableOnce.__subclasses.get(base) 

75 if existing is not None and not SubclassableOnce.TEST_MODE: 

76 raise RuntimeError("This class must be subclassed once at most!") 

77 cls.__subclasses[base] = cls 

78 super().__init_subclass__(**kwargs) 

79 

80 @staticmethod 

81 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]: 

82 for base in cls.__bases__: 

83 if issubclass(base, SubclassableOnce): 

84 return base 

85 else: 

86 raise RuntimeError("wut") 

87 

88 @classmethod 

89 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]": 

90 try: 

91 return cls.get_unique_subclass() 

92 except AttributeError: 

93 return cls 

94 

95 @classmethod 

96 def get_unique_subclass(cls) -> "type[SubclassableOnce]": 

97 existing = SubclassableOnce.__subclasses.get(cls) 

98 if existing is None: 

99 raise AttributeError("Could not find any subclass", cls) 

100 return existing 

101 

102 @classmethod 

103 def reset_subclass(cls) -> None: 

104 log.debug("Resetting subclass of %s", cls) 

105 cls.__subclasses[cls] = None 

106 

107 

108def is_valid_phone_number(phone: str | None) -> bool: 

109 if phone is None: 

110 return False 

111 match = re.match(r"\+\d.*", phone) 

112 if match is None: 

113 return False 

114 return match[0] == phone 

115 

116 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: Text to clean.
    :param repl: Replacement for each matched character (removed by default).
    :return: The cleaned text.
    """
    cleaned = re.sub(ILLEGAL_XML_CHARS_RE, repl, s)
    return cleaned

119 

120 

# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (first, last) code point ranges matched by ILLEGAL_XML_CHARS_RE.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    # The last two code points of each supplementary plane (1 to 16).
    *((plane << 16 | 0xFFFE, plane << 16 | 0xFFFF) for plane in range(0x01, 0x11)),
]

# One "first-last" character-class range per entry of ILLEGAL.
ILLEGAL_RANGES = [f"{chr(first)}-{chr(last)}" for first, last in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = f"[{''.join(ILLEGAL_RANGES)}]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

151 

152 

153# from https://stackoverflow.com/a/35804945/5902284 

154def addLoggingLevel( 

155 levelName: str = "TRACE", 

156 levelNum: int = logging.DEBUG - 5, 

157 methodName: str | None = None, 

158) -> None: 

159 """ 

160 Comprehensively adds a new logging level to the `logging` module and the 

161 currently configured logging class. 

162 

163 `levelName` becomes an attribute of the `logging` module with the value 

164 `levelNum`. `methodName` becomes a convenience method for both `logging` 

165 itself and the class returned by `logging.getLoggerClass()` (usually just 

166 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is 

167 used. 

168 

169 To avoid accidental clobberings of existing attributes, this method will 

170 raise an `AttributeError` if the level name is already an attribute of the 

171 `logging` module or if the method name is already present 

172 

173 Example 

174 ------- 

175 >>> addLoggingLevel('TRACE', logging.DEBUG - 5) 

176 >>> logging.getLogger(__name__).setLevel("TRACE") 

177 >>> logging.getLogger(__name__).trace('that worked') 

178 >>> logging.trace('so did this') 

179 >>> logging.TRACE 

180 5 

181 

182 """ 

183 if not methodName: 

184 methodName = levelName.lower() 

185 

186 if hasattr(logging, levelName): 

187 log.debug(f"{levelName} already defined in logging module") 

188 return 

189 if hasattr(logging, methodName): 

190 log.debug(f"{methodName} already defined in logging module") 

191 return 

192 if hasattr(logging.getLoggerClass(), methodName): 

193 log.debug(f"{methodName} already defined in logger class") 

194 return 

195 

196 # This method was inspired by the answers to Stack Overflow post 

197 # http://stackoverflow.com/q/2183233/2988730, especially 

198 # http://stackoverflow.com/a/13638084/2988730 

199 def logForLevel(self, message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa 

200 if self.isEnabledFor(levelNum): 

201 self._log(levelNum, message, args, **kwargs) 

202 

203 def logToRoot(message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa 

204 logging.log(levelNum, message, *args, **kwargs) 

205 

206 logging.addLevelName(levelNum, levelName) 

207 setattr(logging, levelName, levelNum) 

208 setattr(logging.getLoggerClass(), methodName, logForLevel) 

209 setattr(logging, methodName, logToRoot) 

210 

211 

class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass declaring a ``trace`` method.

    NOTE(review): presumably only there so that type checkers know about the
    ``trace`` method that ``addLoggingLevel`` installs at runtime; confirm.
    """

    def trace(self) -> None:
        """Do nothing: placeholder declaration."""
        return None

215 

216 

# Logger used by the helpers of this module.
log: logging.Logger = logging.getLogger(__name__)

218 

219 

def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Merge the presences of several resources into a single one.

    :param resources: Presence information keyed by resource name.
    :return: ``None`` when there is no resource, the resource itself when
        there is only one, otherwise a new dict with priority 0.
    :raises RuntimeError: If no resource has an empty ``show`` and none has a
        truthy one either.
    """
    candidates = list(resources.values())
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    ranked = sorted(candidates, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in candidates):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }

253 

254 

def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Return ``emoji`` without any variation selector-16 (U+FE0F) character.

    The removal is done on the text itself, without an intermediate UTF-8
    round trip, so text that cannot be encoded (lone surrogates) is accepted.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")

258 

259 

def deprecated(name: str, new: Callable):  # type:ignore[no-untyped-def,type-arg] # noqa
    """
    Build a deprecated alias of ``new``.

    :param name: Name of the deprecated callable, used in the warning text.
    :param new: Callable that replaces it.
    :return: A function emitting a ``DeprecationWarning`` and then forwarding
        all its arguments to ``new``, returning its result.
    """

    # NOTE(review): ``functools.wraps`` was deliberately left commented out in
    # the original, so the alias does not take over the metadata of ``new``.
    # @functools.wraps
    def wrapped(*args, **kwargs):  # type:ignore[no-untyped-def] # noqa
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the caller of the deprecated alias,
            # not to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

270 

271 

T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[T]) -> T:
    """
    Build an instance of the named tuple ``cls`` from ``data``.

    Fields missing from ``data`` are set to ``None``; keys of ``data`` that
    are not fields of ``cls`` are ignored.
    """
    values = map(data.get, cls._fields)
    return cls(*values)  # type:ignore

277 

278 

def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Substitute each mention in ``text`` with the output of ``mapping``.

    Mentions are processed in the order given; each one replaces the
    ``text[mention.start:mention.end]`` span.

    :param text: Original text.
    :param mentions: Mentions to replace; ``text`` is returned unchanged when
        this is empty or ``None``.
    :param mapping: Produces the replacement text of a mention. If it raises,
        it is called again with ``mention.contact`` instead.
    :return: The text with all mentions substituted.
    """
    if not mentions:
        return text

    chunks = []
    position = 0
    for mention in mentions:
        try:
            replacement = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            replacement = mapping(mention.contact)  # type:ignore
        # Untouched text since the previous mention, then the replacement.
        chunks.append(text[position : mention.start])
        chunks.append(replacement)
        position = mention.end
    chunks.append(text[position:])
    return "".join(chunks)

299 

300 

TimeItWrapped = TypeVar("TimeItWrapped")


def timeit(
    func: Callable[..., Awaitable[TimeItWrapped]],
) -> Callable[..., Awaitable[TimeItWrapped]]:
    """
    Decorate an async method so that its duration is logged.

    After the wrapped coroutine completes, the elapsed time, rounded to whole
    milliseconds, is sent to ``self.log.debug``; the instance the method is
    bound to must therefore have a ``log`` attribute.
    """

    @wraps(func)
    async def wrapped(self: object, *args: object, **kwargs: object) -> TimeItWrapped:
        started_at = time()
        result = await func(self, *args, **kwargs)
        elapsed_ms = round((time() - started_at) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)  # type:ignore
        return result

    return wrapped

315 

316 

def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it is made of emoji.

    ``text`` is returned unchanged when the ``emoji`` library is unavailable,
    when it contains no space, or when its first word is not purely emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    head, separator, remainder = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if separator and emoji.purely_emoji(head):
        return remainder
    return text

326 

327 

async def noop_coro() -> None:
    """Coroutine that does nothing and resolves to ``None``."""
    return None

330 

331 

def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Trailing whitespace is removed from every line, so an empty line is
    quoted as a bare ">".
    """
    # Each quoted line starts with ">", hence only its right side can carry
    # strippable whitespace.
    quoted = [f"> {line}".rstrip() for line in text.split("\n")]
    return "\n".join(quoted)