Coverage for slidge / util / util.py: 78%

176 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import logging 

2import mimetypes 

3import re 

4import warnings 

5from collections.abc import Awaitable, Callable 

6from functools import wraps 

7from pathlib import Path 

8from time import time 

9from typing import Any, ClassVar, NamedTuple, TypeVar 

10 

11try: 

12 import emoji 

13except ImportError: 

14 EMOJI_LIB_AVAILABLE = False 

15else: 

16 EMOJI_LIB_AVAILABLE = True 

17 

18from slixmpp.types import ResourceDict 

19 

20from .types import Mention 

21 

22try: 

23 import magic 

24except ImportError as e: 

25 magic = None # type:ignore 

26 logging.warning( 

27 ( 

28 "Libmagic is not available: %s. " 

29 "It's OK if you don't use fix-filename-suffix-mime-type." 

30 ), 

31 e, 

32 ) 

33 

34 

def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the file's actual content type.

    The MIME type is sniffed from the file contents with libmagic and takes
    precedence over ``mime_type``. If the current suffix (of ``file_name``, or
    of ``path.name`` when no file name is given) is one of the extensions that
    :mod:`mimetypes` knows for the sniffed type (compared case-insensitively),
    the name is kept as is. Otherwise the suffix is replaced by the preferred
    extension for that type, when :mod:`mimetypes` knows one.

    :param path: Local path of the file whose content is inspected.
    :param mime_type: The MIME type that was announced for the file, if any.
    :param file_name: The name the file should be presented under, if any.
    :return: A ``(file_name, mime_type)`` tuple, where ``mime_type`` is the
        type detected by libmagic.
    :raises RuntimeError: if the optional ``magic`` module could not be
        imported.
    """
    if magic is None:
        # Fail with an explicit message instead of
        # "'NoneType' object has no attribute 'from_file'".
        raise RuntimeError(
            "Fixing file name suffixes requires libmagic (the 'magic' module), "
            "which is not available"
        )

    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)

    # Drop any MIME parameters (e.g. "; charset=...") before asking mimetypes,
    # for both lookups below.
    base_type = guessed.split(";")[0].strip()
    valid_suffix_list = mimetypes.guess_all_extensions(base_type, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    # Case-insensitive, so that e.g. "photo.JPG" is not renamed for image/jpeg.
    if suffix.lower() in {valid.lower() for valid in valid_suffix_list}:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return str(name), guessed

    valid_suffix = mimetypes.guess_extension(base_type, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return str(name), guessed

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return str(name.with_suffix(valid_suffix)), guessed

65 

66 

67class SubclassableOnce: 

68 # To allow importing everything, including plugins, during tests 

69 TEST_MODE: bool = False 

70 __subclasses: ClassVar[ 

71 dict[type["SubclassableOnce"], type["SubclassableOnce"] | None] 

72 ] = {} 

73 

74 def __init_subclass__(cls, **kwargs: object) -> None: 

75 if SubclassableOnce not in cls.__bases__: 

76 base = SubclassableOnce.__find_direct_child(cls) 

77 existing = SubclassableOnce.__subclasses.get(base) 

78 if existing is not None and not SubclassableOnce.TEST_MODE: 

79 raise RuntimeError("This class must be subclassed once at most!") 

80 cls.__subclasses[base] = cls 

81 super().__init_subclass__(**kwargs) 

82 

83 @staticmethod 

84 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]: 

85 for base in cls.__bases__: 

86 if issubclass(base, SubclassableOnce): 

87 return base 

88 else: 

89 raise RuntimeError("wut") 

90 

91 @classmethod 

92 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]": 

93 try: 

94 return cls.get_unique_subclass() 

95 except AttributeError: 

96 return cls 

97 

98 @classmethod 

99 def get_unique_subclass(cls) -> "type[SubclassableOnce]": 

100 existing = SubclassableOnce.__subclasses.get(cls) 

101 if existing is None: 

102 raise AttributeError("Could not find any subclass", cls) 

103 return existing 

104 

105 @classmethod 

106 def reset_subclass(cls) -> None: 

107 log.debug("Resetting subclass of %s", cls) 

108 cls.__subclasses[cls] = None 

109 

110 

111def is_valid_phone_number(phone: str | None) -> bool: 

112 if phone is None: 

113 return False 

114 match = re.match(r"\+\d.*", phone) 

115 if match is None: 

116 return False 

117 return match[0] == phone 

118 

119 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: Text to sanitize.
    :param repl: Replacement for each matched character (removed by default).
    :return: The sanitized text.
    """
    return re.sub(ILLEGAL_XML_CHARS_RE, repl, s)

122 

123 

# Code point ranges (inclusive) matched by ILLEGAL_XML_CHARS_RE.
# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
#
# Besides the code points forbidden by the XML 1.0 ``Char`` production
# (C0 controls other than tab, LF and CR; surrogates; U+FFFE and U+FFFF),
# this also covers code points the specification discourages: U+007F-U+0084,
# U+0086-U+009F, the U+FDD0-U+FDEF noncharacters and the last two code points
# of every supplementary plane.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # Lone surrogates are not XML characters and cannot be encoded to UTF-8.
    (0xD800, 0xDFFF),
    # The Unicode noncharacter block U+FDD0..U+FDEF is 32 code points long.
    (0xFDD0, 0xFDEF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

# Use \UXXXXXXXX escapes rather than the raw characters: the pattern stays
# printable, and no endpoint can ever be mistaken for class syntax.
ILLEGAL_RANGES = [rf"\U{low:08X}-\U{high:08X}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

154 

155 

156# from https://stackoverflow.com/a/35804945/5902284 

157def addLoggingLevel( 

158 levelName: str = "TRACE", 

159 levelNum: int = logging.DEBUG - 5, 

160 methodName: str | None = None, 

161) -> None: 

162 """ 

163 Comprehensively adds a new logging level to the `logging` module and the 

164 currently configured logging class. 

165 

166 `levelName` becomes an attribute of the `logging` module with the value 

167 `levelNum`. `methodName` becomes a convenience method for both `logging` 

168 itself and the class returned by `logging.getLoggerClass()` (usually just 

169 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is 

170 used. 

171 

172 To avoid accidental clobberings of existing attributes, this method will 

173 raise an `AttributeError` if the level name is already an attribute of the 

174 `logging` module or if the method name is already present 

175 

176 Example 

177 ------- 

178 >>> addLoggingLevel('TRACE', logging.DEBUG - 5) 

179 >>> logging.getLogger(__name__).setLevel("TRACE") 

180 >>> logging.getLogger(__name__).trace('that worked') 

181 >>> logging.trace('so did this') 

182 >>> logging.TRACE 

183 5 

184 

185 """ 

186 if not methodName: 

187 methodName = levelName.lower() 

188 

189 if hasattr(logging, levelName): 

190 log.debug(f"{levelName} already defined in logging module") 

191 return 

192 if hasattr(logging, methodName): 

193 log.debug(f"{methodName} already defined in logging module") 

194 return 

195 if hasattr(logging.getLoggerClass(), methodName): 

196 log.debug(f"{methodName} already defined in logger class") 

197 return 

198 

199 # This method was inspired by the answers to Stack Overflow post 

200 # http://stackoverflow.com/q/2183233/2988730, especially 

201 # http://stackoverflow.com/a/13638084/2988730 

202 def logForLevel(self, message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa 

203 if self.isEnabledFor(levelNum): 

204 self._log(levelNum, message, args, **kwargs) 

205 

206 def logToRoot(message, *args, **kwargs) -> None: # type:ignore[no-untyped-def] # noqa 

207 logging.log(levelNum, message, *args, **kwargs) 

208 

209 logging.addLevelName(levelNum, levelName) 

210 setattr(logging, levelName, levelNum) 

211 setattr(logging.getLoggerClass(), methodName, logForLevel) 

212 setattr(logging, methodName, logToRoot) 

213 

214 

class SlidgeLogger(logging.Logger):
    """
    :class:`logging.Logger` subclass that declares the ``trace`` method.

    ``addLoggingLevel`` in this module attaches a ``trace`` method to the
    logger class at runtime. Declaring it here gives it a real signature. When
    this class is used as an actual logger, ``trace`` logs at
    ``logging.TRACE`` if that level has been registered, and at
    ``logging.DEBUG - 5`` (``addLoggingLevel``'s default level number)
    otherwise.
    """

    def trace(self, msg: object, *args: Any, **kwargs: Any) -> None:
        """
        Log ``msg % args`` at the TRACE level.

        Same calling convention as :meth:`logging.Logger.debug`.
        """
        import sys

        level = getattr(logging, "TRACE", logging.DEBUG - 5)
        if not self.isEnabledFor(level):
            return
        # Since Python 3.11, stacklevel=1 designates this method itself. Add
        # one so the record points at our caller; older versions already
        # report the caller with the default.
        extra = 1 if sys.version_info >= (3, 11) else 0
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + extra
        self._log(level, msg, args, **kwargs)

218 

219 

# Module-level logger used by the helpers in this file.
log = logging.getLogger(name=__name__)

221 

222 

def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Merge the presences of several resources into a single one.

    :param resources: Presence information (show, status, priority), keyed by
        resource.
    :return: ``None`` if there are no resources, or the only entry as is if
        there is exactly one. Otherwise, a new dict with priority 0 whose
        ``show`` is ``""`` if any resource has an empty show, or else the
        first non-empty show by decreasing priority, and whose ``status`` is
        the first non-empty status by decreasing priority (``""`` if none).
    :raises RuntimeError: if no resource has an empty or truthy ``show``.
    """
    if not resources:
        return None

    entries = list(resources.values())
    if len(entries) == 1:
        return entries[0]

    # Highest priority first; sorted() is stable, so ties keep dict order.
    ranked = sorted(entries, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in entries):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }

256 

257 

def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Remove every VARIATION SELECTOR-16 (U+FE0F) from ``emoji``.

    This is required for compatibility with dino, and maybe other future
    clients.

    :param emoji: Text (typically a single emoji) to normalize.
    :return: The text without any U+FE0F code point.
    """
    # Operate on the str directly: a round trip through UTF-8 bytes gives the
    # same result for valid text but raises UnicodeEncodeError on lone
    # surrogates.
    return emoji.replace("\ufe0f", "")

261 

262 

def deprecated(name: str, new: Callable):  # type:ignore[no-untyped-def,type-arg] # noqa
    """
    Build a deprecated alias that forwards to ``new``.

    :param name: The deprecated name, used in the warning message.
    :param new: The replacement callable; every call is forwarded to it.
    :return: A callable that emits a :class:`DeprecationWarning` and then
        returns ``new(*args, **kwargs)``.
    """

    def wrapped(*args, **kwargs):  # type:ignore[no-untyped-def] # noqa
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated alias,
            # not to this wrapper, so the report points at what must change.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

273 

274 

T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict[str, Any], cls: type[T]) -> T:
    """
    Build a ``cls`` instance from the matching keys of ``data``.

    Keys of ``data`` that are not fields of ``cls`` are ignored. A field
    missing from ``data`` takes its default value from ``cls`` when it
    declares one, and ``None`` otherwise.

    :param data: Mapping of field names to values.
    :param cls: The NamedTuple class to instantiate.
    :return: The new ``cls`` instance.
    """
    # Honour the NamedTuple's declared defaults instead of forcing None.
    defaults = cls._field_defaults
    return cls(*(data.get(f, defaults.get(f)) for f in cls._fields))  # type:ignore

280 

281 

def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Replace each mentioned span of ``text`` with the string from ``mapping``.

    :param text: The original text.
    :param mentions: Mentions whose ``start``/``end`` offsets delimit the spans
        to replace. They may come in any order; they are processed by
        increasing ``start``. Overlapping spans are not handled specially.
    :param mapping: Called with each mention to produce its replacement. If
        that call raises, it is retried with ``mention.contact`` for
        compatibility with slidge <= 0.3.3 callbacks.
    :return: The rewritten text, or ``text`` itself when there are no
        mentions.
    """
    if not mentions:
        return text

    cursor = 0
    pieces: list[str] = []
    # Walk the mentions in text order: with an out-of-order list, the slices
    # below would duplicate parts of the text and misplace replacements.
    for mention in sorted(mentions, key=lambda m: m.start):
        try:
            new_text = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            new_text = mapping(mention.contact)  # type:ignore
        pieces.extend([text[cursor : mention.start], new_text])
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)

302 

303 

TimeItWrapped = TypeVar("TimeItWrapped")


def timeit(
    func: Callable[..., Awaitable[TimeItWrapped]],
) -> Callable[..., Awaitable[TimeItWrapped]]:
    """
    Decorate an async method so that its run time is logged.

    After each successful call, ``self.log.debug`` receives the method name
    and the elapsed time in whole milliseconds. Nothing is logged if the
    method raises.

    :param func: The coroutine method to wrap. Its instance must have a
        ``log`` attribute with a ``debug`` method.
    :return: The wrapped coroutine method.
    """
    # perf_counter is meant for measuring durations; time.time() follows
    # wall-clock adjustments, which can skew (even negate) the result.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self: object, *args: object, **kwargs: object) -> TimeItWrapped:
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - start) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)  # type:ignore
        return r

    return wrapped

318 

319 

def strip_leading_emoji(text: str) -> str:
    """
    Drop a leading emoji "word" from ``text``.

    When the emoji library is available and the part of ``text`` before the
    first space consists only of emoji, everything after that first space is
    returned. Otherwise ``text`` is returned unchanged.
    """
    if EMOJI_LIB_AVAILABLE:
        head, separator, rest = text.partition(" ")
        # is_emoji returns False for 🛷️ for obscure reasons,
        # purely_emoji seems better
        if separator and emoji.purely_emoji(head):
            return rest
    return text

329 

330 

async def noop_coro() -> None:
    """Coroutine function whose coroutines do nothing and return ``None``."""
    return None

333 

334 

def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line (split on newline characters) becomes "> " followed by the
    line, with trailing whitespace removed, so blank lines become a bare ">".
    """
    quoted_lines = [f"> {line}".rstrip() for line in text.split("\n")]
    return "\n".join(quoted_lines)