Coverage for slidge / util / util.py: 78%

175 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-03-13 22:59 +0000

1import logging 

2import mimetypes 

3import re 

4import warnings 

5from collections.abc import Callable 

6from functools import wraps 

7from pathlib import Path 

8from time import time 

9from typing import Any, NamedTuple, TypeVar 

10 

11try: 

12 import emoji 

13except ImportError: 

14 EMOJI_LIB_AVAILABLE = False 

15else: 

16 EMOJI_LIB_AVAILABLE = True 

17 

18from slixmpp.types import ResourceDict 

19 

20from .types import Mention 

21 

22try: 

23 import magic 

24except ImportError as e: 

25 magic = None # type:ignore 

26 logging.warning( 

27 ( 

28 "Libmagic is not available: %s. " 

29 "It's OK if you don't use fix-filename-suffix-mime-type." 

30 ), 

31 e, 

32 ) 

33 

34 

def fix_suffix(
    path: Path, mime_type: str | None, file_name: str | None
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the content type detected by libmagic.

    The MIME type is always taken from ``magic.from_file``; ``mime_type`` is
    only compared against it for logging purposes.

    This function needs the ``magic`` module. When its import failed at module
    load time, ``magic`` is ``None`` and calling this function fails.

    :param path: Location of the file to inspect.
    :param mime_type: MIME type announced for the file, if any.
    :param file_name: Name to fix. When empty or ``None``, ``path.name`` is
        used instead.
    :return: A ``(file name, detected MIME type)`` tuple. The file name is
        unchanged when its suffix is already valid for the detected type or
        when no suffix is known for that type.
    """
    guessed = magic.from_file(path, mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)

    # Both mimetypes lookups below must see the same value: the bare type,
    # without any ";parameter" part, otherwise they may disagree.
    bare_mime_type = guessed.split(";")[0]

    valid_suffix_list = mimetypes.guess_all_extensions(bare_mime_type, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    if suffix in valid_suffix_list:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return str(name), guessed

    valid_suffix = mimetypes.guess_extension(bare_mime_type, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return str(name), guessed

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return str(name.with_suffix(valid_suffix)), guessed

65 

66 

67class SubclassableOnce: 

68 # To allow importing everything, including plugins, during tests 

69 TEST_MODE: bool = False 

70 __subclasses: dict[type["SubclassableOnce"], type["SubclassableOnce"] | None] = {} 

71 

72 def __init_subclass__(cls, **kwargs: Any) -> None: 

73 if SubclassableOnce not in cls.__bases__: 

74 base = SubclassableOnce.__find_direct_child(cls) 

75 existing = SubclassableOnce.__subclasses.get(base) 

76 if existing is not None and not SubclassableOnce.TEST_MODE: 

77 raise RuntimeError("This class must be subclassed once at most!") 

78 cls.__subclasses[base] = cls 

79 super().__init_subclass__(**kwargs) 

80 

81 @staticmethod 

82 def __find_direct_child(cls: type["SubclassableOnce"]) -> type["SubclassableOnce"]: 

83 for base in cls.__bases__: 

84 if issubclass(base, SubclassableOnce): 

85 return base 

86 else: 

87 raise RuntimeError("wut") 

88 

89 @classmethod 

90 def get_self_or_unique_subclass(cls) -> "type[SubclassableOnce]": 

91 try: 

92 return cls.get_unique_subclass() 

93 except AttributeError: 

94 return cls 

95 

96 @classmethod 

97 def get_unique_subclass(cls) -> "type[SubclassableOnce]": 

98 existing = SubclassableOnce.__subclasses.get(cls) 

99 if existing is None: 

100 raise AttributeError("Could not find any subclass", cls) 

101 return existing 

102 

103 @classmethod 

104 def reset_subclass(cls) -> None: 

105 log.debug("Resetting subclass of %s", cls) 

106 cls.__subclasses[cls] = None 

107 

108 

109def is_valid_phone_number(phone: str | None): 

110 if phone is None: 

111 return False 

112 match = re.match(r"\+\d.*", phone) 

113 if match is None: 

114 return False 

115 return match[0] == phone 

116 

117 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every character matched by ``ILLEGAL_XML_CHARS_RE`` in ``s``.

    :param s: The text to clean up.
    :param repl: What each matched character is replaced with; by default
        the character is simply dropped.
    :return: The cleaned-up text.
    """
    cleaned = re.sub(ILLEGAL_XML_CHARS_RE, repl, s)
    return cleaned

120 

121 

# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges matched by ILLEGAL_XML_CHARS_RE.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    # The last two code points of each of the planes 1 to 16.
    *((plane << 16 | 0xFFFE, plane << 16 | 0xFFFF) for plane in range(0x01, 0x11)),
]

# One "low-high" character class range per entry, using the raw characters.
ILLEGAL_RANGES = [chr(low) + "-" + chr(high) for low, high in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = f"[{''.join(ILLEGAL_RANGES)}]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

152 

153 

# from https://stackoverflow.com/a/35804945/5902284
def addLoggingLevel(
    levelName: str = "TRACE", levelNum: int = logging.DEBUG - 5, methodName=None
) -> None:
    """
    Comprehensively adds a new logging level to the `logging` module and the
    currently configured logging class.

    `levelName` becomes an attribute of the `logging` module with the value
    `levelNum`. `methodName` becomes a convenience method for both `logging`
    itself and the class returned by `logging.getLoggerClass()` (usually just
    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
    used.

    To avoid accidental clobberings of existing attributes, this function
    does nothing (apart from emitting a debug message) if the level name is
    already an attribute of the `logging` module, or if the method name is
    already present on the `logging` module or on the logger class.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel("TRACE")
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5

    """
    if not methodName:
        methodName = levelName.lower()

    if hasattr(logging, levelName):
        log.debug("%s already defined in logging module", levelName)
        return
    if hasattr(logging, methodName):
        log.debug("%s already defined in logging module", methodName)
        return
    if hasattr(logging.getLoggerClass(), methodName):
        log.debug("%s already defined in logger class", methodName)
        return

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs) -> None:
        # Installed as a method of the logger class.
        if self.isEnabledFor(levelNum):
            self._log(levelNum, message, args, **kwargs)

    def logToRoot(message, *args, **kwargs) -> None:
        # Installed as a function of the logging module itself.
        logging.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logging.getLoggerClass(), methodName, logForLevel)
    setattr(logging, methodName, logToRoot)

209 

210 

class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass that declares a ``trace`` method.

    NOTE(review): presumably only a declaration for type checkers, with the
    working ``trace`` being installed on the logger class at runtime; confirm.
    """

    def trace(self) -> None:
        """Do nothing."""
        return None

214 

215 

# Module-level logger. Functions defined earlier in this module refer to it
# by name, which is fine since they only look it up when they are called.
log = logging.getLogger(__name__)

217 

218 

def merge_resources(resources: dict[str, ResourceDict]) -> ResourceDict | None:
    """
    Combine the presences of several resources into a single one.

    :param resources: Presence information, keyed by resource name.
    :return: ``None`` when there is no resource at all; the only entry,
        untouched, when there is exactly one; otherwise a new dict with a
        priority of 0 and the merged "show" and "status" values.
    """
    if not resources:
        return None

    if len(resources) == 1:
        (only,) = resources.values()
        return only

    ranked = sorted(resources.values(), key=lambda res: res["priority"], reverse=True)

    someone_available = any(res["show"] == "" for res in resources.values())
    if someone_available:
        # if a client is "available", we're "available"
        show = ""
    else:
        # otherwise the highest priority resource with a "show" wins
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }

252 

253 

def remove_emoji_variation_selector_16(emoji: str):
    """
    Return ``emoji`` without any "variation selector-16" (U+FE0F) character.

    The replacement is done directly on the string, so text that cannot be
    encoded as UTF-8 (e.g. containing a lone surrogate) is handled too.

    :param emoji: The text to strip the selectors from.
    :return: The text with every U+FE0F removed.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")

257 

258 

def deprecated(name: str, new: Callable):
    """
    Build a stand-in for a deprecated callable.

    :param name: Name of the deprecated callable, used in the warning text.
    :param new: The replacement; every call is forwarded to it.
    :return: A function that emits a ``DeprecationWarning`` and then returns
        whatever ``new`` returns for the same arguments.
    """

    # @functools.wraps
    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # attribute the warning to the code calling the deprecated name,
            # not to this wrapper
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

269 

270 

# Type of the named tuple built by dict_to_named_tuple().
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: type[T]) -> T:
    """
    Build an instance of the named tuple ``cls`` from ``data``.

    :param data: Values keyed by field name; keys that are not fields of
        ``cls`` are ignored.
    :param cls: The named tuple class to instantiate.
    :return: The new instance, with ``None`` for each field missing in ``data``.
    """
    values = map(data.get, cls._fields)
    return cls(*values)  # type:ignore

276 

277 

def replace_mentions(
    text: str,
    mentions: list[Mention] | None,
    mapping: Callable[[Mention], str],
):
    """
    Substitute each mentioned span of ``text`` with the output of ``mapping``.

    :param text: The original text.
    :param mentions: The mentions to replace, each giving the ``start`` and
        ``end`` offsets of its span. They are processed in list order.
    :param mapping: Called with each mention to produce its replacement. If
        that call raises, it is retried with the mention's ``contact``.
    :return: The text with all mentions replaced; ``text`` itself if there is
        no mention.
    """
    if not mentions:
        return text

    chunks = []
    position = 0
    for item in mentions:
        try:
            replacement = mapping(item)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            replacement = mapping(item.contact)  # type:ignore
        # whatever lies between the previous mention and this one is kept as is
        chunks.append(text[position : item.start])
        chunks.append(replacement)
        position = item.end
    chunks.append(text[position:])
    return "".join(chunks)

298 

299 

def timeit(func):
    """
    Decorate an async method so that the duration of each call gets logged.

    The duration is logged, in milliseconds, with ``self.log.debug`` once the
    decorated coroutine has returned; nothing is logged if it raises.

    :param func: The async method to time. Its first argument must be an
        object with a ``log`` attribute.
    :return: The decorated async method, returning what ``func`` returns.
    """
    # A monotonic clock is used so that system clock adjustments happening
    # during the call cannot distort the measurement.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self, *args, **kwargs):
        start = perf_counter()
        r = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - start) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)
        return r

    return wrapped

309 

310 

def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it only contains emoji.

    :param text: The text to process.
    :return: The text without its leading emoji word; ``text`` unchanged when
        the emoji library is not available, when it contains no space, or
        when its first word is not purely made of emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    first_word, space, remainder = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if space and emoji.purely_emoji(first_word):
        return remainder
    return text

320 

321 

async def noop_coro() -> None:
    """Coroutine function that does nothing and returns ``None``."""
    return None

324 

325 

def add_quote_prefix(text: str):
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Each quoted line is stripped of its surrounding whitespace, and so is the
    final result.
    """
    quoted_lines = []
    for line in text.split("\n"):
        quoted_lines.append(f"> {line}".strip())
    return "\n".join(quoted_lines).strip()