Coverage for slidge/util/util.py: 74%

174 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import logging 

2import mimetypes 

3import re 

4import subprocess 

5import warnings 

6from abc import ABCMeta 

7from functools import wraps 

8from pathlib import Path 

9from time import time 

10from typing import TYPE_CHECKING, Callable, NamedTuple, Optional, Type, TypeVar 

11 

12try: 

13 import emoji 

14except ImportError: 

15 EMOJI_LIB_AVAILABLE = False 

16else: 

17 EMOJI_LIB_AVAILABLE = True 

18 

19from .types import Mention, ResourceDict 

20 

21if TYPE_CHECKING: 

22 from ..contact.contact import LegacyContact 

23 

24try: 

25 import magic 

26except ImportError as e: 

27 magic = None # type:ignore 

28 logging.warning( 

29 ( 

30 "Libmagic is not available: %s. " 

31 "It's OK if you don't use fix-filename-suffix-mime-type." 

32 ), 

33 e, 

34 ) 

35 

36 

def fix_suffix(path: Path, mime_type: Optional[str], file_name: Optional[str]):
    """
    Return a file name whose suffix agrees with the file's detected MIME type.

    The MIME type is detected from the file content with libmagic
    (``magic.from_file(..., mime=True)``). If it differs from ``mime_type``,
    the detected type wins.

    The name is taken from ``file_name`` when it is non-empty, otherwise from
    ``path.name``. If its suffix is one of the extensions ``mimetypes`` knows
    for the MIME type (compared case-insensitively), the name is returned
    unchanged. Otherwise the suffix is replaced by
    ``mimetypes.guess_extension``'s answer, or kept when there is none.

    Requires the optional ``magic`` module. If it failed to import, ``magic``
    is None and this function raises ``AttributeError``.

    :param path: file on disk whose content is inspected
    :param mime_type: MIME type announced for the file, may be None
    :param file_name: announced file name, may be None or empty
    :return: a ``Path`` made from the (possibly re-suffixed) file name
    """
    # str() so that python-magic versions accepting only str paths work too.
    guessed = magic.from_file(str(path), mime=True)
    if guessed == mime_type:
        log.debug("Magic and given MIME match")
    else:
        log.debug("Magic (%s) and given MIME (%s) differ", guessed, mime_type)
    mime_type = guessed

    # Drop MIME parameters (e.g. "; charset=...") once, so that both mimetypes
    # lookups below use the same bare type.
    bare_mime_type = mime_type.split(";")[0].strip()
    valid_suffix_list = mimetypes.guess_all_extensions(bare_mime_type, strict=False)

    name = Path(file_name) if file_name else Path(path.name)
    suffix = name.suffix

    # Case-insensitive comparison: "photo.JPG" is a fine name for image/jpeg
    # and should not be renamed to "photo.jpg".
    if suffix.lower() in {s.lower() for s in valid_suffix_list}:
        log.debug("Suffix %s is in %s", suffix, valid_suffix_list)
        return name

    valid_suffix = mimetypes.guess_extension(bare_mime_type, strict=False)
    if valid_suffix is None:
        log.debug("No valid suffix found")
        return name

    log.debug("Changing suffix of %s to %s", file_name or path.name, valid_suffix)
    return name.with_suffix(valid_suffix)

65 

66 

class SubclassableOnce(type):
    """
    Metaclass that registers the single class derived from a base class.

    When a class is created, each base whose metaclass is exactly
    ``SubclassableOnce`` or ``ABCSubclassableOnceAtMost`` gets the new class
    stored as its ``_subclass`` attribute.

    If such a base already has a ``_subclass`` attribute, ``RuntimeError`` is
    raised instead, unless ``TEST_MODE`` is true. Because the check uses
    ``hasattr``, an attribute inherited from an ancestor also counts.
    """

    TEST_MODE = False  # To allow importing everything, including plugins, during tests

    def __init__(cls, name, bases, dct):
        for base in bases:
            # Only bases built by one of the two "subclassable once"
            # metaclasses take part in the registration.
            if type(base) not in (SubclassableOnce, ABCSubclassableOnceAtMost):
                continue
            if hasattr(base, "_subclass") and not cls.TEST_MODE:
                raise RuntimeError(
                    "This class must be subclassed once at most!",
                    cls,
                    name,
                    bases,
                    dct,
                )
            log.debug("Setting %s as subclass for %s", cls, base)
            base._subclass = cls

        super().__init__(name, bases, dct)

    def get_self_or_unique_subclass(cls):
        """Return the registered subclass of ``cls``, or ``cls`` itself if there is none."""
        try:
            resolved = cls.get_unique_subclass()
        except AttributeError:
            resolved = cls
        return resolved

    def get_unique_subclass(cls):
        """
        Return the class registered as ``cls._subclass``.

        :raises AttributeError: if the attribute is missing or None
        """
        subclass = getattr(cls, "_subclass", None)
        if subclass is not None:
            return subclass
        raise AttributeError("Could not find any subclass", cls)

    def reset_subclass(cls):
        """
        Delete the ``_subclass`` attribute set directly on ``cls``.

        When there is nothing to delete, only a debug message is logged.
        """
        try:
            log.debug("Resetting subclass of %s", cls)
            delattr(cls, "_subclass")
        except AttributeError:
            log.debug("No subclass were registered for %s", cls)

105 

106 

class ABCSubclassableOnceAtMost(ABCMeta, SubclassableOnce):
    """Metaclass combining ``ABCMeta`` with ``SubclassableOnce`` for abstract base classes."""

109 

110 

def is_valid_phone_number(phone: Optional[str]):
    """
    Tell whether ``phone`` looks like an international phone number.

    The check is deliberately loose. The whole string must be a "+" followed
    by a digit, then any characters except a newline. ``None`` is never valid.
    """
    if phone is None:
        return False
    return re.fullmatch(r"\+\d.*", phone) is not None

118 

119 

def strip_illegal_chars(s: str):
    """
    Remove characters that must not, or should not, appear in XML text.

    Every code point matched by ``ILLEGAL_XML_CHARS_RE``, which is built from
    the ``ILLEGAL`` ranges below, is deleted. All other characters are kept,
    including tab, newline and carriage return.
    """
    return ILLEGAL_XML_CHARS_RE.sub("", s)


# from https://stackoverflow.com/a/64570125/5902284 and Link Mauve
# Inclusive (low, high) code point ranges to strip. 0x09, 0x0A, 0x0D and 0x85
# fall in gaps between the ranges and are therefore kept.
ILLEGAL = [
    (0x00, 0x08),
    (0x0B, 0x0C),
    (0x0E, 0x1F),
    (0x7F, 0x84),
    (0x86, 0x9F),
    # UTF-16 surrogates. A Python str only contains them when malformed (e.g.
    # decoded with "surrogateescape"). They are not XML Chars and cannot be
    # encoded as UTF-8, so they would break serialization.
    (0xD800, 0xDFFF),
    (0xFDD0, 0xFDDF),
    (0xFFFE, 0xFFFF),
    (0x1FFFE, 0x1FFFF),
    (0x2FFFE, 0x2FFFF),
    (0x3FFFE, 0x3FFFF),
    (0x4FFFE, 0x4FFFF),
    (0x5FFFE, 0x5FFFF),
    (0x6FFFE, 0x6FFFF),
    (0x7FFFE, 0x7FFFF),
    (0x8FFFE, 0x8FFFF),
    (0x9FFFE, 0x9FFFF),
    (0xAFFFE, 0xAFFFF),
    (0xBFFFE, 0xBFFFF),
    (0xCFFFE, 0xCFFFF),
    (0xDFFFE, 0xDFFFF),
    (0xEFFFE, 0xEFFFF),
    (0xFFFFE, 0xFFFFF),
    (0x10FFFE, 0x10FFFF),
]

# Character-class ranges made of the literal characters. No bound is a regex
# metacharacter ("]", "\\", "^", "-"), so no escaping is needed.
ILLEGAL_RANGES = [f"{chr(low)}-{chr(high)}" for (low, high) in ILLEGAL]
XML_ILLEGAL_CHARACTER_REGEX = "[" + "".join(ILLEGAL_RANGES) + "]"
ILLEGAL_XML_CHARS_RE = re.compile(XML_ILLEGAL_CHARACTER_REGEX)

154 

155 

# from https://stackoverflow.com/a/35804945/5902284
def addLoggingLevel(
    levelName: str = "TRACE",
    levelNum: int = logging.DEBUG - 5,
    methodName: Optional[str] = None,
):
    """
    Comprehensively adds a new logging level to the `logging` module and the
    currently configured logging class.

    `levelName` becomes an attribute of the `logging` module with the value
    `levelNum`. `methodName` becomes a convenience method for both `logging`
    itself and the class returned by `logging.getLoggerClass()` (usually just
    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
    used.

    To avoid accidentally clobbering existing attributes, this function does
    nothing in either of these cases; it only emits a debug log record and
    returns, without raising:

    - the level name is already an attribute of the `logging` module;
    - the method name is already an attribute of the `logging` module or of
      the logger class.

    Calling it twice with the same arguments is therefore harmless.

    Example
    -------
    >>> addLoggingLevel('TRACE', logging.DEBUG - 5)
    >>> logging.getLogger(__name__).setLevel("TRACE")
    >>> logging.getLogger(__name__).trace('that worked')
    >>> logging.trace('so did this')
    >>> logging.TRACE
    5

    """
    # Looked up here rather than relying on the module-level logger, which is
    # defined further down the file.
    logger = logging.getLogger(__name__)

    if not methodName:
        methodName = levelName.lower()

    if hasattr(logging, levelName):
        logger.debug("%s already defined in logging module", levelName)
        return
    if hasattr(logging, methodName):
        logger.debug("%s already defined in logging module", methodName)
        return
    if hasattr(logging.getLoggerClass(), methodName):
        logger.debug("%s already defined in logger class", methodName)
        return

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def logForLevel(self, message, *args, **kwargs):
        if self.isEnabledFor(levelNum):
            self._log(levelNum, message, args, **kwargs)

    def logToRoot(message, *args, **kwargs):
        logging.log(levelNum, message, *args, **kwargs)

    logging.addLevelName(levelNum, levelName)
    setattr(logging, levelName, levelNum)
    setattr(logging.getLoggerClass(), methodName, logForLevel)
    setattr(logging, methodName, logToRoot)

211 

212 

class SlidgeLogger(logging.Logger):
    """
    ``logging.Logger`` subclass that declares a ``trace`` method.

    ``trace`` here is a no-op stub. The working ``trace`` method is the one
    ``addLoggingLevel`` attaches to ``logging.getLoggerClass()``.
    NOTE(review): this class appears to exist so that ``log.trace(...)``
    type-checks. Confirm how callers use it.
    """

    def trace(self, *args: object, **kwargs: object) -> None:
        # Accept the same call shapes as Logger.debug & co. (message, %-args,
        # exc_info=..., ...) so that real calls type-check. Still a no-op.
        pass


# Logger for this module. The functions defined above only look it up when
# they run, i.e. after the module has finished importing.
log = logging.getLogger(__name__)

219 

220 

def get_version():
    """
    Return a version string derived from the current git commit.

    Runs ``git rev-parse HEAD`` in the process's current working directory.
    Returns ``"git-"`` followed by the first 10 characters of the commit hash,
    or ``"NO_VERSION"`` when git cannot be run or exits with an error (for
    example, outside a git work tree).
    """
    try:
        git = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        ).decode()
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing git binary (FileNotFoundError), but also
        # PermissionError and similar failures to launch the process.
        return "NO_VERSION"
    return "git-" + git[:10]

232 

233 

def merge_resources(resources: dict[str, ResourceDict]) -> Optional[ResourceDict]:
    """
    Merge the presences of several resources into a single presence dict.

    - No resources: return None.
    - One resource: return that resource's dict itself (not a copy).
    - Otherwise build a new dict:
      - ``show`` is "" ("available") if any resource has an empty show, else
        the first non-empty show in descending priority order;
      - ``status`` is the first non-empty status in descending priority
        order, or "" if none; resources without a status are skipped even
        when they rank higher;
      - ``priority`` is always 0.

    :raises RuntimeError: if no resource has an empty show and none has a
        truthy one (presumably impossible for well-formed ResourceDicts).
    """
    if not resources:
        return None

    candidates = list(resources.values())
    if len(candidates) == 1:
        return candidates[0]

    # sorted() is stable, so equal priorities keep their insertion order.
    ranked = sorted(candidates, key=lambda res: res["priority"], reverse=True)

    if any(res["show"] == "" for res in candidates):
        merged_show = ""
    else:
        merged_show = next((res["show"] for res in ranked if res["show"]), None)
        if merged_show is None:
            raise RuntimeError()

    merged_status = next((res["status"] for res in ranked if res["status"]), "")

    return {
        "show": merged_show,  # type:ignore
        "status": merged_status,
        "priority": 0,
    }

267 

268 

def remove_emoji_variation_selector_16(emoji: str):
    """
    Remove every U+FE0F (VARIATION SELECTOR-16) character from ``emoji``.

    This is required for compatibility with dino, and maybe other future
    clients.

    Working on the str directly, instead of a UTF-8 encode/replace/decode
    round trip, gives the same result for valid text. It also does not raise
    on strings that contain lone surrogates.
    """
    # The parameter name shadows the module-level ``emoji`` import. It is kept
    # for compatibility with keyword callers.
    return emoji.replace("\ufe0f", "")

272 

273 

def deprecated(name: str, new: Callable):
    """
    Return a callable that forwards to ``new`` and emits a DeprecationWarning.

    ``functools.wraps`` is intentionally not applied. The wrapper stands in
    for a differently named alias, and ``new`` may not be a plain function.

    :param name: name of the deprecated alias, used in the warning message
    :param new: replacement callable; its ``__name__`` appears in the message
    :return: a wrapper that warns on every call, then returns
        ``new(*args, **kwargs)``
    """

    def wrapped(*args, **kwargs):
        warnings.warn(
            f"{name} is deprecated. Use {new.__name__} instead",
            category=DeprecationWarning,
            # Attribute the warning to the code calling the deprecated name,
            # not to this wrapper. Otherwise it points into this module and
            # the default filters hide it.
            stacklevel=2,
        )
        return new(*args, **kwargs)

    return wrapped

284 

285 

T = TypeVar("T", bound=NamedTuple)


def dict_to_named_tuple(data: dict, cls: Type[T]) -> T:
    """
    Build an instance of the NamedTuple class ``cls`` from a mapping.

    Each field of ``cls`` is looked up in ``data`` by name, and keys of
    ``data`` that are not fields are ignored. A field missing from ``data``
    gets the default declared on ``cls`` if there is one, otherwise ``None``.
    """
    defaults = cls._field_defaults
    return cls(*(data.get(f, defaults.get(f)) for f in cls._fields))  # type:ignore

291 

292 

def replace_mentions(
    text: str,
    mentions: Optional[list[Mention]],
    mapping: Callable[["LegacyContact"], str],
):
    """
    Replace each mentioned span of ``text`` with ``mapping(mention.contact)``.

    ``mention.start`` and ``mention.end`` are used as slice bounds into
    ``text``: ``text[start:end]`` is the span that gets replaced.

    Mentions are processed in increasing ``start`` order, so callers do not
    have to pass them sorted. Overlapping spans are not supported.

    :return: the rewritten text, or ``text`` itself when ``mentions`` is
        empty or None
    """
    if not mentions:
        return text

    cursor = 0
    pieces = []
    # Walk the mentions left to right. An out-of-order mention would
    # otherwise slice ``text`` backwards and duplicate or drop parts of it.
    for mention in sorted(mentions, key=lambda m: m.start):
        pieces.append(text[cursor : mention.start])
        pieces.append(mapping(mention.contact))
        cursor = mention.end
    pieces.append(text[cursor:])
    return "".join(pieces)

308 

309 

def with_session(func):
    """
    Decorate an async method so it runs inside ``self.xmpp.store.session()``.

    The session context is entered before awaiting ``func`` and exited once it
    returns or raises. The method's return value is passed through, and
    ``functools.wraps`` keeps ``func``'s name and docstring.
    """

    async def run_in_session(self, *args, **kwargs):
        with self.xmpp.store.session():
            return await func(self, *args, **kwargs)

    return wraps(func)(run_in_session)

317 

318 

def timeit(func):
    """
    Decorate an async method so each call's duration is logged.

    After ``func`` returns, ``self.log.info`` is called with the function name
    and the elapsed wall-clock time (from ``time.time``), rounded to whole
    milliseconds. Nothing is logged if ``func`` raises.
    """

    @wraps(func)
    async def timed(self, *args, **kwargs):
        started_at = time()
        result = await func(self, *args, **kwargs)
        elapsed_ms = round((time() - started_at) * 1000)
        self.log.info("%s took %s ms", func.__name__, elapsed_ms)
        return result

    return timed

328 

329 

def strip_leading_emoji(text: str) -> str:
    """
    Drop the first space-separated word of ``text`` if it is made only of emoji.

    "Only emoji" is judged by ``emoji.purely_emoji``. Everything after the
    first single space is returned. ``text`` is returned unchanged when the
    emoji library is unavailable, when ``text`` contains no space, or when the
    first word is not purely emoji.
    """
    if not EMOJI_LIB_AVAILABLE:
        return text
    head, separator, tail = text.partition(" ")
    # is_emoji returns False for 🛷️ for obscure reasons,
    # purely_emoji seems better
    if separator and emoji.purely_emoji(head):
        return tail
    return text