Coverage for slidge / util / util.py: 78%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

1import logging 

2import mimetypes 

3import re 

4import warnings 

5from abc import ABCMeta 

6from collections.abc import Callable 

7from functools import wraps 

8from pathlib import Path 

9from time import time 

10from typing import Any, NamedTuple, TypeVar 

11 

12try: 

13 import emoji 

14except ImportError: 

15 EMOJI_LIB_AVAILABLE = False 

16else: 

17 EMOJI_LIB_AVAILABLE = True 

18 

19from slixmpp.types import ResourceDict 

20 

21from .types import Mention 

22 

23try: 

24 import magic 

25except ImportError as e: 

26 magic = None # type:ignore 

27 logging.warning( 

28 ( 

29 "Libmagic is not available: %s. " 

30 "It's OK if you don't use fix-filename-suffix-mime-type." 

31 ), 

32 e, 

33 ) 

34 

35 

def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the MIME type libmagic detects.

    The content of ``path`` is inspected with ``magic.from_file``; whatever
    libmagic reports takes precedence over the ``mime_type`` given by the
    caller. This needs the ``magic`` module: the import fallback at the top of
    this file sets ``magic`` to ``None`` when libmagic is missing, in which
    case this function cannot be used.

    :param path: File whose content is inspected.
    :param mime_type: MIME type announced for the file, if any.
    :param file_name: Name to fix; ``path.name`` is used when it is empty or
        ``None``.
    :return: A ``(file name, MIME type detected by libmagic)`` pair.
    """
    detected = magic.from_file(path, mime=True)
    if detected != mime_type:
        log.debug("Magic (%s) and given MIME (%s) differ", detected, mime_type)
        mime_type = detected
    else:
        log.debug("Magic and given MIME match")

    known_suffixes = mimetypes.guess_all_extensions(mime_type, strict=False)
    name = Path(file_name) if file_name else Path(path.name)
    current_suffix = name.suffix

    # Nothing to fix when the current suffix is already acceptable.
    if current_suffix in known_suffixes:
        log.debug("Suffix %s is in %s", current_suffix, known_suffixes)
        return str(name), detected

    replacement = mimetypes.guess_extension(mime_type.split(";")[0], strict=False)
    if replacement is not None:
        log.debug("Changing suffix of %s to %s", file_name or path.name, replacement)
        return str(name.with_suffix(replacement)), detected

    # mimetypes knows no extension for this type: keep the name untouched.
    log.debug("No valid suffix found")
    return str(name), detected

66 

67 

class SubclassableOnce(type):
    """
    Metaclass that remembers the subclass of the classes it manages.

    When a class is created with a base whose type is ``SubclassableOnce`` or
    ``ABCSubclassableOnceAtMost``, the new class is stored on that base as
    ``_subclass``. If the base already exposes a ``_subclass`` attribute, a
    ``RuntimeError`` is raised, unless ``TEST_MODE`` is true.
    """

    # To allow importing everything, including plugins, during tests
    TEST_MODE: bool = False

    def __init__(
        cls,
        name: str,
        bases: tuple[type[Any], ...],
        dct: dict[str, Any],
    ) -> None:
        for base in bases:
            # Only bases managed by one of the two metaclasses are tracked.
            if type(base) not in (SubclassableOnce, ABCSubclassableOnceAtMost):
                continue
            if hasattr(base, "_subclass") and not cls.TEST_MODE:
                raise RuntimeError(
                    "This class must be subclassed once at most!",
                    cls,
                    name,
                    bases,
                    dct,
                )
            log.debug("Setting %s as subclass for %s", cls, base)
            base._subclass = cls

        super().__init__(name, bases, dct)

    def get_self_or_unique_subclass(cls) -> "SubclassableOnce":
        """Return the registered subclass, or ``cls`` itself if there is none."""
        try:
            found = cls.get_unique_subclass()
        except AttributeError:
            return cls
        return found

    def get_unique_subclass(cls) -> "SubclassableOnce":
        """
        Return the registered subclass.

        :raises AttributeError: if ``_subclass`` is missing or ``None``.
        """
        registered = getattr(cls, "_subclass", None)
        if registered is not None:
            return registered
        raise AttributeError("Could not find any subclass", cls)

    def reset_subclass(cls) -> None:
        """Forget the registered subclass; only logs if there is none to remove."""
        try:
            log.debug("Resetting subclass of %s", cls)
            del cls._subclass
        except AttributeError:
            log.debug("No subclass were registered for %s", cls)

113 

class ABCSubclassableOnceAtMost(ABCMeta, SubclassableOnce):
    """Metaclass combining ``ABCMeta`` with ``SubclassableOnce``; adds nothing of its own."""

116 

117 

118def is_valid_phone_number(phone: str | None): 

119 if phone is None: 

120 return False 

121 match = re.match(r"\+\d.*", phone) 

122 if match is None: 

123 return False 

124 return match[0] == phone 

125 

126 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: Text to clean.
    :param repl: Replacement for each matched character; by default the
        character is simply removed.
    :return: The cleaned text.
    """
    cleaned = ILLEGAL_XML_CHARS_RE.sub(repl, s)
    return cleaned

129 

130 

# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges that are removed from text.
ILLEGAL = [
    # C0 controls, except tab (0x09), line feed (0x0A) and carriage return (0x0D)
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    # DEL and C1 controls, except NEL (0x85)
    (0x7F, 0x84),
    (0x86, 0x9F),
    # Unicode noncharacters U+FDD0..U+FDEF. This range used to stop at 0xFDDF,
    # which let U+FDE0..U+FDEF through although XML 1.0 (section 2.2) lists the
    # whole #xFDD0-#xFDEF range.
    (0xFDD0, 0xFDEF),
    # The last two code points of each of the 17 planes:
    # (0xFFFE, 0xFFFF), (0x1FFFE, 0x1FFFF), ..., (0x10FFFE, 0x10FFFF)
    *(((plane << 16) | 0xFFFE, (plane << 16) | 0xFFFF) for plane in range(0x11)),
]

# None of the range bounds is a character with a special meaning inside a
# regex character class, so they can be inserted as-is.
ILLEGAL_RANGES = [f"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

161 

162 

163# from https://stackoverflow.com/a/35804945/5902284 

164def addLoggingLevel( 

165 levelName: str = "TRACE", levelNum: int = logging.DEBUG - 5, methodName=None 

166) -> None: 

167 """ 

168 Comprehensively adds a new logging level to the `logging` module and the 

169 currently configured logging class. 

170 

171 `levelName` becomes an attribute of the `logging` module with the value 

172 `levelNum`. `methodName` becomes a convenience method for both `logging` 

173 itself and the class returned by `logging.getLoggerClass()` (usually just 

174 `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is 

175 used. 

176 

177 To avoid accidental clobberings of existing attributes, this method will 

178 raise an `AttributeError` if the level name is already an attribute of the 

179 `logging` module or if the method name is already present 

180 

181 Example 

182 ------- 

183 >>> addLoggingLevel('TRACE', logging.DEBUG - 5) 

184 >>> logging.getLogger(__name__).setLevel("TRACE") 

185 >>> logging.getLogger(__name__).trace('that worked') 

186 >>> logging.trace('so did this') 

187 >>> logging.TRACE 

188 5 

189 

190 """ 

191 if not methodName: 

192 methodName = levelName.lower() 

193 

194 if hasattr(logging, levelName): 

195 log.debug(f"{levelName} already defined in logging module") 

196 return 

197 if hasattr(logging, methodName): 

198 log.debug(f"{methodName} already defined in logging module") 

199 return 

200 if hasattr(logging.getLoggerClass(), methodName): 

201 log.debug(f"{methodName} already defined in logger class") 

202 return 

203 

204 # This method was inspired by the answers to Stack Overflow post 

205 # http://stackoverflow.com/q/2183233/2988730, especially 

206 # http://stackoverflow.com/a/13638084/2988730 

207 def logForLevel(self, message, *args, **kwargs) -> None: 

208 if self.isEnabledFor(levelNum): 

209 self._log(levelNum, message, args, **kwargs) 

210 

211 def logToRoot(message, *args, **kwargs) -> None: 

212 logging.log(levelNum, message, *args, **kwargs) 

213 

214 logging.addLevelName(levelNum, levelName) 

215 setattr(logging, levelName, levelNum) 

216 setattr(logging.getLoggerClass(), methodName, logForLevel) 

217 setattr(logging, methodName, logToRoot) 

218 

219 

class SlidgeLogger(logging.Logger):
    """``logging.Logger`` subclass that declares a ``trace`` method."""

    def trace(self) -> None:
        """
        Do nothing.

        NOTE(review): presumably a placeholder so that type checkers know about
        the ``trace`` method that ``addLoggingLevel`` installs by default;
        confirm against callers.
        """
        return None


# Logger for this module; the functions in this file look it up when called.
log = logging.getLogger(__name__)

226 

227 

def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Collapse the presences of several resources into a single one.

    :param resources: Presence of each resource; every value must provide the
        ``"show"``, ``"status"`` and ``"priority"`` keys.
    :return: ``None`` for an empty mapping; the only value when there is a
        single resource; otherwise a new dict with priority 0, whose ``show``
        is ``""`` if any resource has an empty ``show`` (else the ``show`` of
        the highest priority resource that has one), and whose ``status`` is
        the first non-empty status by decreasing priority (``""`` if none).
    :raises RuntimeError: if no ``show`` value could be determined.
    """
    if not resources:
        return None

    if len(resources) == 1:
        (only,) = resources.values()
        return only

    ranked = sorted(resources.values(), key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in resources.values()):
        # if a client is "available", we're "available"
        show = ""
    else:
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }

261 

262 

def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Remove every VARIATION SELECTOR-16 (U+FE0F) from ``emoji``.

    This is required for compatibility with dino, and maybe other future
    clients.

    The replacement is done directly on the string. Going through UTF-8 bytes,
    as was done before, gives the same result for valid text but raises
    ``UnicodeEncodeError`` when the string contains a lone surrogate.

    :param emoji: Text, typically a single emoji, to clean.
    :return: The text without any U+FE0F code point.
    """
    return emoji.replace("\ufe0f", "")

266 

267 

def deprecated(name: str, new: Callable):
    """
    Build a stand-in for a deprecated callable.

    The returned function emits a ``DeprecationWarning`` mentioning ``name``
    and ``new.__name__``, then forwards all its arguments to ``new`` and
    returns its result.

    :param name: Name of the deprecated callable, used in the warning text.
    :param new: Replacement that is actually called.
    :return: The wrapping function.
    """

    # NOTE(review): the wrapper deliberately does not use functools.wraps (it
    # was already disabled here), so it keeps its own name and docstring.
    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated name
            # rather than to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

278 

279 

T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: type[T]) -> T:
    """
    Instantiate the named tuple class ``cls`` from the content of ``data``.

    :param data: Mapping from field names to values. Keys that are not fields
        of ``cls`` are ignored; missing fields are set to ``None``.
    :param cls: Named tuple class to instantiate.
    :return: An instance of ``cls``.
    """
    values = map(data.get, cls._fields)
    return cls(*values)  # type:ignore

285 

286 

def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
) -> str:
    """
    Substitute the mentions found in ``text``.

    The span ``text[mention.start:mention.end]`` of each mention is replaced
    with ``mapping(mention)``. Mentions are processed in the order given.

    :param text: Original text.
    :param mentions: Mentions to replace; ``text`` is returned unchanged when
        this is empty or ``None``.
    :param mapping: Returns the replacement text for a mention. If it raises,
        it is called again with ``mention.contact`` instead.
    :return: The text with all mentions replaced.
    """
    if not mentions:
        return text

    chunks = []
    position = 0
    for mention in mentions:
        try:
            replacement = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            replacement = mapping(mention.contact)  # type:ignore
        # Text between the previous mention and this one, then the replacement.
        chunks.append(text[position : mention.start])
        chunks.append(replacement)
        position = mention.end
    # Whatever follows the last mention.
    chunks.append(text[position:])
    return "".join(chunks)

307 

308 

def timeit(func):
    """
    Decorate an async method so that its duration is logged at debug level.

    The decorated method must belong to an object with a ``log`` attribute;
    once ``func`` has returned, ``self.log.debug`` receives the name of
    ``func`` and the elapsed time, rounded to milliseconds. Nothing is logged
    if ``func`` raises.

    :param func: Coroutine function taking ``self`` as first argument.
    :return: The wrapping coroutine function.
    """
    # Durations are measured with a monotonic clock: the wall clock returned by
    # time.time() can jump when the system time is adjusted.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self, *args, **kwargs):
        started = perf_counter()
        result = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - started) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)
        return result

    return wrapped

318 

319 

def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it is made of emoji only.

    :param text: Text to process.
    :return: ``text`` without its leading emoji word; ``text`` unchanged when
        the ``emoji`` library is unavailable, when there is no space in
        ``text``, or when the first word is not purely emoji.
    """
    if EMOJI_LIB_AVAILABLE:
        first_word, space, remainder = text.partition(" ")
        # is_emoji returns False for 🛷️ for obscure reasons,
        # purely_emoji seems better
        if space and emoji.purely_emoji(first_word):
            return remainder
    return text

329 

330 

async def noop_coro() -> None:
    """Coroutine that does nothing and completes with ``None``."""
    return None

333 

334 

def add_quote_prefix(text: str) -> str:
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each line is prefixed with "> " and then stripped of surrounding
    whitespace, so an empty line becomes a bare ">". The joined result is
    stripped as well.
    """
    quoted_lines = [("> " + line).strip() for line in text.split("\n")]
    quoted = "\n".join(quoted_lines)
    return quoted.strip()