Coverage for slidge / util / util.py: 78%

170 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-06 15:18 +0000

1import logging 

2import mimetypes 

3import re 

4import warnings 

5from abc import ABCMeta 

6from functools import wraps 

7from pathlib import Path 

8from time import time 

9from typing import Any, Callable, NamedTuple, Optional, Type, TypeVar 

10 

11try: 

12 import emoji 

13except ImportError: 

14 EMOJI_LIB_AVAILABLE = False 

15else: 

16 EMOJI_LIB_AVAILABLE = True 

17 

18from slixmpp.types import ResourceDict 

19 

20from .types import Mention 

21 

22try: 

23 import magic 

24except ImportError as e: 

25 magic = None # type:ignore 

26 logging.warning( 

27 ( 

28 "Libmagic is not available: %s. " 

29 "It's OK if you don't use fix-filename-suffix-mime-type." 

30 ), 

31 e, 

32 ) 

33 

34 

def fix_suffix(
    path: Path, mime_type: Optional[str], file_name: Optional[str]
) -> tuple[str, str]:
    """
    Make a file name's suffix agree with the MIME type libmagic detects.

    :param path: File whose content is inspected with ``magic.from_file``.
    :param mime_type: MIME type announced for the file. It is only compared
        with the detected one for logging; the detected type always wins.
    :param file_name: Name to fix. When empty or ``None``, ``path.name`` is
        used instead.
    :return: ``(file name, detected MIME type)``. The name is unchanged when
        its suffix is already valid for the detected type or when no suffix
        is known for that type.
    """
    detected = magic.from_file(path, mime=True)
    if detected != mime_type:
        log.debug("Magic (%s) and given MIME (%s) differ", detected, mime_type)
        mime_type = detected
    else:
        log.debug("Magic and given MIME match")

    known_suffixes = mimetypes.guess_all_extensions(mime_type, strict=False)

    target = Path(file_name) if file_name else Path(path.name)
    current_suffix = target.suffix

    # Nothing to do when the current suffix is already acceptable.
    if current_suffix in known_suffixes:
        log.debug("Suffix %s is in %s", current_suffix, known_suffixes)
        return str(target), detected

    # Drop any ";parameter" part before asking for the preferred suffix.
    preferred_suffix = mimetypes.guess_extension(
        mime_type.split(";")[0], strict=False
    )
    if preferred_suffix is None:
        log.debug("No valid suffix found")
        return str(target), detected

    log.debug(
        "Changing suffix of %s to %s", file_name or path.name, preferred_suffix
    )
    return str(target.with_suffix(preferred_suffix)), detected

65 

66 

class SubclassableOnce(type):
    """
    Metaclass that remembers the subclass of each class it creates.

    When a class is created, every base whose metaclass is exactly
    ``SubclassableOnce`` or ``ABCSubclassableOnceAtMost`` gets the new class
    stored in its ``_subclass`` attribute. If such a base already exposes a
    ``_subclass`` attribute, a ``RuntimeError`` is raised unless ``TEST_MODE``
    is enabled.
    """

    # To allow importing everything, including plugins, during tests
    TEST_MODE: bool = False

    def __init__(
        cls,
        name: str,
        bases: tuple[Type[Any], ...],
        dct: dict[str, Any],
    ) -> None:
        for base in bases:
            # Only bases created by one of the two metaclasses are tracked.
            if type(base) not in (SubclassableOnce, ABCSubclassableOnceAtMost):
                continue
            if hasattr(base, "_subclass") and not cls.TEST_MODE:
                raise RuntimeError(
                    "This class must be subclassed once at most!",
                    cls,
                    name,
                    bases,
                    dct,
                )
            log.debug("Setting %s as subclass for %s", cls, base)
            base._subclass = cls

        super().__init__(name, bases, dct)

    def get_self_or_unique_subclass(cls) -> "SubclassableOnce":
        """Return the registered subclass, or ``cls`` itself if there is none."""
        try:
            found = cls.get_unique_subclass()
        except AttributeError:
            return cls
        return found

    def get_unique_subclass(cls) -> "SubclassableOnce":
        """
        Return the registered subclass.

        :raises AttributeError: if ``_subclass`` is missing or ``None``.
        """
        registered = getattr(cls, "_subclass", None)
        if registered is not None:
            return registered
        raise AttributeError("Could not find any subclass", cls)

    def reset_subclass(cls) -> None:
        """Delete the ``_subclass`` attribute of ``cls``, if it has one."""
        try:
            log.debug("Resetting subclass of %s", cls)
            delattr(cls, "_subclass")
        except AttributeError:
            log.debug("No subclass were registered for %s", cls)

111 

112 

class ABCSubclassableOnceAtMost(ABCMeta, SubclassableOnce):
    """Metaclass combining ``ABCMeta`` with ``SubclassableOnce``; adds nothing of its own."""

115 

116 

def is_valid_phone_number(phone: Optional[str]):
    """
    Tell whether ``phone`` looks like a phone number.

    A value is accepted when the whole string consists of a ``+``, a digit,
    and then any characters other than a newline. ``None`` is rejected.
    """
    if phone is None:
        return False
    found = re.match(r"\+\d.*", phone)
    # The pattern must cover the entire input, not just a prefix of it.
    return found is not None and found.group(0) == phone

124 

125 

def strip_illegal_chars(s: str, repl: str = "") -> str:
    """
    Replace every match of ``ILLEGAL_XML_CHARS_RE`` in ``s`` with ``repl``.

    :param s: Text to clean.
    :param repl: Replacement for each matched character (removed by default).
    :return: The cleaned text.
    """
    return re.sub(ILLEGAL_XML_CHARS_RE, repl, s)

128 

129 

# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
#
# Inclusive (low, high) code point ranges matched by ILLEGAL_XML_CHARS_RE.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # U+FDD0..U+FDEF is the full block of Unicode noncharacters; an upper
    # bound of 0xFDDF would only cover the first half of it.
    (0xFDD0, 0xFDEF),
    # The last two code points of every plane (U+FFFE/U+FFFF up to
    # U+10FFFE/U+10FFFF) are noncharacters as well.
    *(((plane << 16) | 0xFFFE, (plane << 16) | 0xFFFF) for plane in range(0x11)),
]

# None of the range bounds is special inside a regex character class
# (no "]", "\\", "^" or "-"), so they can be inserted unescaped.
ILLEGAL_RANGES = [f"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

160 

161 

# from https://stackoverflow.com/a/35804945/5902284
def addLoggingLevel(
    levelName: str = "TRACE",
    levelNum: int = logging.DEBUG - 5,
    methodName: Optional[str] = None,
) -> None:
    """
    Comprehensively adds a new logging level to the `logging` module and the
    currently configured logging class.

    `levelName` becomes an attribute of the `logging` module with the value
    `levelNum`. `methodName` becomes a convenience method for both `logging`
    itself and the class returned by `logging.getLoggerClass()` (usually just
    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
    used.

    To avoid accidental clobberings of existing attributes, this function
    does nothing (apart from emitting a debug log) if the level name is
    already an attribute of the `logging` module or if the method name is
    already present on the `logging` module or on the logger class. Calling
    it several times with the same arguments is therefore harmless.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel("TRACE")
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5

    """
    if not methodName:
        methodName = levelName.lower()

    # Bail out instead of overwriting anything that already exists.
    if hasattr(logging, levelName):
        log.debug("%s already defined in logging module", levelName)
        return
    if hasattr(logging, methodName):
        log.debug("%s already defined in logging module", methodName)
        return
    if hasattr(logging.getLoggerClass(), methodName):
        log.debug("%s already defined in logger class", methodName)
        return

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs) -> None:
        # Mirrors the built-in level methods: check the level, then hand the
        # positional arguments to ``_log`` as a single tuple.
        if self.isEnabledFor(levelNum):
            self._log(levelNum, message, args, **kwargs)

    def logToRoot(message, *args, **kwargs) -> None:
        logging.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logging.getLoggerClass(), methodName, logForLevel)
    setattr(logging, methodName, logToRoot)

217 

218 

class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass declaring a ``trace`` method.

    NOTE(review): presumably a typing aid for the ``trace`` method that
    ``addLoggingLevel`` installs on the logger class -- confirm against the
    call sites.
    """

    def trace(self) -> None:
        """Do nothing."""
        return None

222 

223 

# Module-level logger, named after this module, used by the helpers in this file.
log = logging.getLogger(__name__)

225 

226 

def merge_resources(resources: dict[str, ResourceDict]) -> Optional[ResourceDict]:
    """
    Collapse the presences of several resources into a single one.

    :param resources: Presence information keyed by resource.
    :return: ``None`` when there is no resource, the only entry when there is
        exactly one, otherwise a new dict with priority 0 whose ``show`` and
        ``status`` are picked as described in the comments below.
    :raises RuntimeError: if no resource has ``show == ""`` and none has a
        truthy ``show`` either.
    """
    if not resources:
        return None

    presences = list(resources.values())
    if len(presences) == 1:
        return presences[0]

    ranked = sorted(presences, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in presences):
        # if a client is "available", we're "available"
        show = ""
    else:
        # otherwise take the "show" of the highest priority resource that has one
        show = next((res["show"] for res in ranked if res["show"]), None)
        if show is None:
            raise RuntimeError()

    # if there are different statuses, we use the highest priority one,
    # but we ignore resources without status, even with high priority
    status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": show,  # type:ignore
        "status": status,
        "priority": 0,
    }

260 

261 

def remove_emoji_variation_selector_16(emoji: str) -> str:
    """
    Return ``emoji`` without any VARIATION SELECTOR-16 (U+FE0F) character.

    The removal is done directly on the string: no encoding round trip is
    involved, so text that cannot be encoded to UTF-8 (e.g. containing a
    lone surrogate) is handled too instead of raising ``UnicodeEncodeError``.

    :param emoji: Text (typically a single emoji) to clean.
    :return: The text with every U+FE0F removed.
    """
    # this is required for compatibility with dino, and maybe other future clients?
    return emoji.replace("\ufe0f", "")

265 

266 

def deprecated(name: str, new: Callable) -> Callable:
    """
    Build a deprecated alias for ``new``.

    :param name: Name of the deprecated callable, used in the warning text.
    :param new: Replacement callable; every call is forwarded to it.
    :return: A wrapper that emits a ``DeprecationWarning`` and then returns
        ``new(*args, **kwargs)``.
    """

    # NOTE(review): functools.wraps is not applied (it was commented out in
    # the original code); confirm whether the wrapper should copy the
    # metadata of ``new``.
    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated name,
            # not to this wrapper.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

277 

278 

# Type variable standing for the concrete named tuple class being built.
T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: Type[T]) -> T:
    """
    Instantiate the named tuple class ``cls`` from ``data``.

    Each field of ``cls`` is looked up in ``data``; fields missing from the
    dict are passed as ``None`` and keys that are not fields are ignored.
    """
    values = [data.get(field) for field in cls._fields]
    return cls(*values)  # type:ignore

284 

285 

def replace_mentions(
    text: str,
    mentions: Optional[list[Mention]],
    mapping: Callable[[Mention], str],
):
    """
    Substitute each mentioned span of ``text`` with ``mapping(mention)``.

    :param text: Original text.
    :param mentions: Mentions providing ``start`` and ``end`` offsets into
        ``text``. They are processed in the order given.
    :param mapping: Produces the replacement text for a mention. If it raises,
        it is called again with ``mention.contact``.
    :return: ``text`` unchanged when there is no mention, otherwise the text
        with every mention replaced.
    """
    if not mentions:
        return text

    chunks = []
    position = 0
    for mention in mentions:
        try:
            replacement = mapping(mention)
        except Exception as exc:
            log.debug("Attempting slidge <= 0.3.3 compatibility: %s", exc)
            replacement = mapping(mention.contact)  # type:ignore
        # Untouched text before the mention, then its replacement.
        chunks.append(text[position : mention.start])
        chunks.append(replacement)
        position = mention.end
    # Whatever follows the last mention.
    chunks.append(text[position:])
    return "".join(chunks)

306 

307 

def timeit(func):
    """
    Decorate an async method so that its duration is logged at debug level.

    The wrapped coroutine awaits ``func(self, *args, **kwargs)``, then logs
    the elapsed time, rounded to whole milliseconds, through ``self.log``
    and returns the result of ``func``. Nothing is logged when ``func``
    raises.

    Durations are measured with ``time.perf_counter``, which, unlike
    ``time.time``, is not affected by adjustments of the system clock.
    """
    # Function-scope import: keeps this block independent from the
    # module-level imports.
    from time import perf_counter

    @wraps(func)
    async def wrapped(self, *args, **kwargs):
        started = perf_counter()
        result = await func(self, *args, **kwargs)
        elapsed_ms = round((perf_counter() - started) * 1000)
        self.log.debug("%s took %s ms", func.__name__, elapsed_ms)
        return result

    return wrapped

317 

318 

def strip_leading_emoji(text: str) -> str:
    """
    Remove the first space-separated word of ``text`` if it is purely emoji.

    The text is returned unchanged when the ``emoji`` library is not
    available, when it contains no space, or when its first word is not
    made of emoji only.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    first_word, space, remainder = text.partition(" ")
    # is_emoji returns False for some emoji (e.g. the sled one) for obscure
    # reasons, purely_emoji seems better
    if space and emoji.purely_emoji(first_word):
        return remainder
    return text

328 

329 

async def noop_coro() -> None:
    """Coroutine that does nothing and returns ``None``."""
    return None

332 

333 

def add_quote_prefix(text: str):
    """
    Return multi-line text with leading quote marks (i.e. the ">" character).

    Every line, as delimited by ``"\\n"``, is prefixed with ``"> "`` and
    loses its trailing whitespace, so an empty line becomes a bare ``">"``.
    """
    # Each quoted line starts with ">", so only trailing whitespace can ever
    # be stripped, both per line and on the joined result.
    quoted_lines = [("> " + line).rstrip() for line in text.split("\n")]
    return "\n".join(quoted_lines)