Coverage for slidge/core/mixins/attachment.py: 82%

270 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import base64 

2import functools 

3import logging 

4import os 

5import re 

6import shutil 

7import stat 

8import tempfile 

9import warnings 

10from datetime import datetime 

11from itertools import chain 

12from mimetypes import guess_extension, guess_type 

13from pathlib import Path 

14from typing import IO, AsyncIterator, Collection, Optional, Sequence, Union 

15from urllib.parse import quote as urlquote 

16from uuid import uuid4 

17from xml.etree import ElementTree as ET 

18 

19import thumbhash 

20from PIL import Image, ImageOps 

21from slixmpp import JID, Message 

22from slixmpp.exceptions import IqError, IqTimeout 

23from slixmpp.plugins.xep_0363 import FileUploadError 

24from slixmpp.plugins.xep_0447.stanza import StatelessFileSharing 

25 

26from ...db.avatar import avatar_cache 

27from ...slixfix.xep_0264.stanza import Thumbnail 

28from ...util.types import ( 

29 LegacyAttachment, 

30 LegacyMessageType, 

31 LegacyThreadType, 

32 MessageReference, 

33) 

34from ...util.util import fix_suffix 

35from .. import config 

36from .message_text import TextMessageMixin 

37 

38 

class AttachmentMixin(TextMessageMixin):
    """
    Mixin that lets an :term:`XMPP Entity` send files.

    A file is either uploaded with HTTP File Upload (:xep:`0363`) or, when
    ``config.NO_UPLOAD_PATH`` is set, placed under that local directory and
    referenced by a URL built from ``config.NO_UPLOAD_URL_PREFIX``. Sent
    messages carry the URL as body and out-of-band data, plus SIMS
    (:xep:`0385`) and SFS (:xep:`0447`) metadata. URLs and SIMS/SFS payloads
    are cached in ``self.xmpp.store.attachments``.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        self.__store = self.xmpp.store.attachments

    async def __upload(
        self,
        file_path: Path,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
    ) -> Optional[str]:
        """
        Upload a local file with HTTP File Upload (:xep:`0363`).

        :param file_path: The file to upload.
        :param file_name: Name to upload the file as. When it differs from
            ``file_path.name``, a symlink with that name is created in a
            private temporary directory and uploaded instead (presumably
            because the upload plugin derives the name from the path).
        :param content_type: MIME type passed to the upload plugin.
        :return: The URL of the uploaded file, or ``None`` if the upload
            service reported an error (a warning is emitted in that case).
        """
        link_dir: Optional[Path] = None
        if file_name and file_path.name != file_name:
            link_dir = Path(tempfile.mkdtemp())

        if config.UPLOAD_SERVICE:
            domain = None
        else:
            # Drop the first DNS label of our JID ("slidge.example.org" ->
            # "example.org"), presumably the server offering the upload service.
            domain = re.sub(r"^.*?\.", "", self.xmpp.boundjid.bare)
        try:
            if link_dir is not None:
                link = link_dir / file_name  # type:ignore[operator]
                link.symlink_to(file_path)
                file_path = link
            new_url = await self.xmpp.plugin["xep_0363"].upload_file(
                filename=file_path,
                content_type=content_type,
                ifrom=config.UPLOAD_REQUESTER or self.xmpp.boundjid,
                domain=JID(domain),
            )
        except (FileUploadError, IqError, IqTimeout) as e:
            warnings.warn(f"Something is wrong with the upload service: {e!r}")
            return None
        finally:
            # Removes the symlink (never its target) and its directory, also
            # when creating the symlink itself failed.
            if link_dir is not None:
                shutil.rmtree(link_dir, ignore_errors=True)

        return new_url

    @staticmethod
    async def __no_upload(
        file_path: Path,
        file_name: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ) -> tuple[Path, str]:
        """
        Publish a file by placing it under ``config.NO_UPLOAD_PATH``.

        The file ends up at ``NO_UPLOAD_PATH/<file_id>/<random uuid>/<name>``,
        where ``file_id`` is the legacy file ID if given and a fresh UUID
        otherwise. If ``NO_UPLOAD_PATH/<file_id>`` already contains exactly
        one file, that file is reused. If it contains several files or none,
        the directory is deleted and recreated.

        :param file_path: The file to publish.
        :param file_name: Name of the published file (default:
            ``file_path.name``).
        :param legacy_file_id: Stable identifier used to reuse a previously
            published copy.
        :return: The published file's path, and its path relative to
            ``NO_UPLOAD_PATH`` with ``/`` separators.
        :raises RuntimeError: if ``config.NO_UPLOAD_METHOD`` is unknown.
        """
        # NOTE(review): legacy_file_id is used verbatim as a path component
        # (and the directory may be rmtree'd); confirm plugins never pass IDs
        # containing path separators.
        file_id = str(uuid4()) if legacy_file_id is None else str(legacy_file_id)
        assert config.NO_UPLOAD_PATH is not None
        assert config.NO_UPLOAD_URL_PREFIX is not None
        destination_dir = Path(config.NO_UPLOAD_PATH) / file_id

        if destination_dir.exists():
            log.debug("Dest dir exists: %s", destination_dir)
            files = [f for f in destination_dir.glob("**/*") if f.is_file()]
            if len(files) == 1:
                found = files[0]
                log.debug(
                    "Found the legacy attachment '%s' at '%s'",
                    legacy_file_id,
                    found,
                )
                # found is <file_id>/<uuid>/<name>, see the UUID trick below
                return found, "/".join([file_id, found.parent.name, found.name])
            log.warning(
                (
                    "There are several or zero files in %s, "
                    "slidge doesn't know which one to pick among %s. "
                    "Removing the dir."
                ),
                destination_dir,
                files,
            )
            shutil.rmtree(destination_dir)

        log.debug("Did not find a file in: %s", destination_dir)
        # A random path component keeps URLs from being guessable from the
        # legacy file ID alone.
        uu = str(uuid4())
        destination_dir = destination_dir / uu
        destination_dir.mkdir(parents=True)

        name = file_name or file_path.name
        destination = destination_dir / name
        method = config.NO_UPLOAD_METHOD
        if method == "copy":
            shutil.copy2(file_path, destination)
        elif method == "hardlink":
            os.link(file_path, destination)
        elif method == "symlink":
            # The target is a regular file, so this must not be created as a
            # directory symlink (the distinction only matters on Windows).
            os.symlink(file_path, destination)
        elif method == "move":
            shutil.move(file_path, destination)
        else:
            raise RuntimeError("No upload method not recognized", method)

        if config.NO_UPLOAD_FILE_READ_OTHERS:
            log.debug("Changing perms of %s", destination)
            destination.chmod(destination.stat().st_mode | stat.S_IROTH)
        uploaded_url = "/".join([file_id, uu, name])

        return destination, uploaded_url

    async def __get_url(
        self,
        file_path: Optional[Path] = None,
        async_data_stream: Optional[AsyncIterator[bytes]] = None,
        data_stream: Optional[IO[bytes]] = None,
        data: Optional[bytes] = None,
        file_url: Optional[str] = None,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ) -> tuple[bool, Optional[Path], Optional[str]]:
        """
        Make a file available over HTTP and return its URL.

        The URL is resolved in this order:

        1. A URL cached for ``legacy_file_id``, if an HTTP HEAD request on it
           still succeeds. Stale entries are dropped from the cache.
        2. ``file_url`` itself, if ``config.USE_ATTACHMENT_ORIGINAL_URLS``.
        3. Otherwise the file is put on disk and then published.
           - If ``file_path`` is not given, the first available of
             ``file_url``, ``data_stream``, ``async_data_stream`` or
             ``data`` is written to a private temporary directory.
           - Publishing uses :xep:`0363` upload or, if
             ``config.NO_UPLOAD_PATH`` is set, places the file in that
             directory.

        :return: ``(is_temp, local_path, url)``.
            - ``is_temp`` tells the caller to delete ``local_path`` and its
              parent (temporary) directory once done with it.
            - ``local_path`` is ``None`` when a cached or original URL is
              returned.
            - ``url`` is ``None`` if the upload failed.
        """
        if legacy_file_id:
            cache = self.__store.get_url(str(legacy_file_id))
            if cache is not None:
                async with self.session.http.head(cache) as r:
                    if r.status < 400:
                        return False, None, cache
                    else:
                        self.__store.remove(str(legacy_file_id))

        if file_url and config.USE_ATTACHMENT_ORIGINAL_URLS:
            return False, None, file_url

        if file_name:
            # File names usually come from the legacy network. Neutralise path
            # separators so the name cannot escape the directory it is placed
            # in (temporary dir, NO_UPLOAD_PATH, or the upload symlink dir).
            for sep in (os.sep, os.altsep):
                if sep:
                    file_name = file_name.replace(sep, "_")
            if file_name in (".", ".."):
                file_name = None

        max_length = config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH
        if file_name and len(file_name) > max_length:
            log.debug("Trimming long filename: %s", file_name)
            base, ext = os.path.splitext(file_name)
            if len(ext) < max_length:
                # keep the extension, shorten the stem
                file_name = base[: max_length - len(ext)] + ext
            else:
                # the "extension" alone does not fit: plain truncation
                file_name = file_name[:max_length]

        temp_dir: Optional[Path] = None
        if file_path is None:
            if file_name is None:
                file_name = str(uuid4())
                if content_type is not None:
                    ext = guess_extension(content_type, strict=False)  # type:ignore
                    if ext is not None:
                        file_name += ext
            temp_dir = Path(tempfile.mkdtemp())
            file_path = temp_dir / file_name
            try:
                if file_url:
                    async with self.session.http.get(file_url) as r:
                        with file_path.open("wb") as f:
                            f.write(await r.read())
                elif data_stream is not None:
                    data = data_stream.read()
                    if data is None:
                        raise RuntimeError
                    with file_path.open("wb") as f:
                        f.write(data)
                elif async_data_stream is not None:
                    # TODO: patch slixmpp to allow this as data source for
                    # upload_file() so we don't even have to write anything
                    # to disk.
                    with file_path.open("wb") as f:
                        async for chunk in async_data_stream:
                            f.write(chunk)
                elif data is not None:
                    with file_path.open("wb") as f:
                        f.write(data)
            except BaseException:
                # do not leave a partially written file behind
                shutil.rmtree(temp_dir, ignore_errors=True)
                raise

            # In no-upload mode the returned local_path lives under
            # NO_UPLOAD_PATH and must not be deleted by the caller; the
            # temporary copy is dealt with below instead.
            is_temp = not config.NO_UPLOAD_PATH
        else:
            is_temp = False

        if config.FIX_FILENAME_SUFFIX_MIME_TYPE:
            file_name = str(fix_suffix(file_path, content_type, file_name))

        new_url: Optional[str]
        if config.NO_UPLOAD_PATH:
            try:
                local_path, relative_url = await self.__no_upload(
                    file_path, file_name, legacy_file_id
                )
            finally:
                # Except with the "symlink" method (the published file points
                # into temp_dir), the temporary copy is no longer needed.
                if temp_dir is not None and config.NO_UPLOAD_METHOD != "symlink":
                    shutil.rmtree(temp_dir, ignore_errors=True)
            new_url = (
                (config.NO_UPLOAD_URL_PREFIX or "") + "/" + urlquote(relative_url)
            )
        else:
            local_path = file_path
            new_url = await self.__upload(file_path, file_name, content_type)

        # never cache a failed upload
        if legacy_file_id and new_url is not None:
            self.__store.set_url(self.session.user_pk, str(legacy_file_id), new_url)

        return is_temp, local_path, new_url

    async def __set_sims(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        content_type: Optional[str] = None,
        caption: Optional[str] = None,
        file_name: Optional[str] = None,
    ) -> Thumbnail | None:
        """
        Attach a SIMS (:xep:`0385`) reference describing the file to ``msg``.

        The serialized reference is cached per URL and reused on a cache hit.
        Otherwise it is built from the local file at ``path``. Nothing is
        attached if ``path`` is empty. For content types starting with
        ``image``, a thumbhash thumbnail is added when one can be computed.

        :return: The thumbnail element, if any, so it can be reused for SFS.
        """
        cache = self.__store.get_sims(uploaded_url)
        if cache:
            ref = self.xmpp["xep_0372"].stanza.Reference(xml=ET.fromstring(cache))
            msg.append(ref)
            if ref["sims"]["file"].get_plugin("thumbnail", check=True):
                return ref["sims"]["file"]["thumbnail"]
            return None

        if not path:
            return None

        ref = self.xmpp["xep_0385"].get_sims(
            path, [uploaded_url], content_type, caption
        )
        if file_name:
            ref["sims"]["file"]["name"] = file_name
        thumbnail = None
        if content_type is not None and content_type.startswith("image"):
            try:
                h, x, y = await self.xmpp.loop.run_in_executor(
                    avatar_cache._thread_pool, get_thumbhash, path
                )
            except Exception as e:
                # best effort: the file is still sent, just without thumbnail
                log.debug("Could not generate a thumbhash", exc_info=e)
            else:
                thumbnail = ref["sims"]["file"]["thumbnail"]
                thumbnail["width"] = x
                thumbnail["height"] = y
                thumbnail["media-type"] = "image/thumbhash"
                thumbnail["uri"] = "data:image/thumbhash;base64," + urlquote(h)

        self.__store.set_sims(uploaded_url, str(ref))

        msg.append(ref)

        return thumbnail

    def __set_sfs(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        content_type: Optional[str] = None,
        caption: Optional[str] = None,
        file_name: Optional[str] = None,
        thumbnail: Optional[Thumbnail] = None,
    ):
        """
        Attach a Stateless File Sharing (:xep:`0447`) element to ``msg``.

        The serialized element is cached per URL and reused on a cache hit.
        Otherwise it is built from the local file at ``path``. Nothing is
        attached if ``path`` is empty. ``thumbnail`` (typically the one
        produced for SIMS) is embedded when given.
        """
        cache = self.__store.get_sfs(uploaded_url)
        if cache:
            msg.append(StatelessFileSharing(xml=ET.fromstring(cache)))
            return

        if not path:
            return

        sfs = self.xmpp["xep_0447"].get_sfs(path, [uploaded_url], content_type, caption)
        if file_name:
            sfs["file"]["name"] = file_name
        if thumbnail is not None:
            sfs["file"].append(thumbnail)
        self.__store.set_sfs(uploaded_url, str(sfs))

        msg.append(sfs)

    def __send_url(
        self,
        msg: Message,
        legacy_msg_id: LegacyMessageType,
        uploaded_url: str,
        caption: Optional[str] = None,
        carbon=False,
        when: Optional[datetime] = None,
        correction=False,
        **kwargs,
    ) -> list[Message]:
        """
        Send ``msg`` as a plain attachment, with body and OOB URL both set to
        ``uploaded_url``.

        With a caption, the caption is sent as a separate text message that
        carries ``legacy_msg_id``. Without a caption, the attachment message
        itself carries ``legacy_msg_id``, as a correction of it if
        ``correction`` is set.

        :return: The sent messages, caption message (if any) last.
        """
        msg["oob"]["url"] = uploaded_url
        msg["body"] = uploaded_url
        if caption:
            m1 = self._send(msg, carbon=carbon, **kwargs)
            m2 = self.send_text(
                caption,
                legacy_msg_id=legacy_msg_id,
                when=when,
                carbon=carbon,
                correction=correction,
                **kwargs,
            )
            return [m1, m2] if m2 else [m1]
        else:
            if correction:
                msg["replace"]["id"] = self._replace_id(legacy_msg_id)
            else:
                self._set_msg_id(msg, legacy_msg_id)
            return [self._send(msg, carbon=carbon, **kwargs)]

    async def send_file(
        self,
        file_path: Optional[Union[Path, str]] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        *,
        async_data_stream: Optional[AsyncIterator[bytes]] = None,
        data_stream: Optional[IO[bytes]] = None,
        data: Optional[bytes] = None,
        file_url: Optional[str] = None,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        caption: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> tuple[Optional[str], list[Message]]:
        """
        Send a single file from this :term:`XMPP Entity`.

        :param file_path: Path to the attachment
        :param async_data_stream: Alternatively (and ideally) an AsyncIterator yielding bytes
        :param data_stream: Alternatively, a stream of bytes (such as a File object)
        :param data: Alternatively, a bytes object
        :param file_url: Alternatively, a URL
        :param file_name: How the file should be named.
        :param content_type: MIME type, inferred from filename if not given
        :param legacy_msg_id: If you want to be able to transport read markers from the gateway
            user to the legacy network, specify this
        :param reply_to: Quote another message (:xep:`0461`)
        :param when: when the file was sent, for a "delay" tag (:xep:`0203`)
        :param caption: an optional text that is linked to the file
        :param legacy_file_id: A unique identifier for the file on the legacy network.
            Plugins should try their best to provide it, to avoid duplicates.
        :param thread: The legacy thread the message belongs to; also used
            when retracting previous messages on correction.
        :return: The file URL (``None`` if publishing it failed, in which case
            an error notice is sent instead) and the sent messages.
        """
        carbon = kwargs.pop("carbon", False)
        mto = kwargs.pop("mto", None)
        store_multi = kwargs.pop("store_multi", True)
        # not popped: __send_url receives it through kwargs
        correction = kwargs.get("correction", False)
        if correction and (original_xmpp_id := self._legacy_to_xmpp(legacy_msg_id)):
            # A corrected attachment replaces the original message; retract
            # the extra messages (e.g. caption) that were sent alongside it.
            xmpp_ids = self.xmpp.store.multi.get_xmpp_ids(
                self.session.user_pk, original_xmpp_id
            )

            for xmpp_id in xmpp_ids:
                if xmpp_id == original_xmpp_id:
                    continue
                self.retract(xmpp_id, thread)

        if reply_to is not None and reply_to.body:
            # We cannot have a "quote fallback" for attachments since most (all?)
            # XMPP clients will only treat a message as an attachment if the
            # body is the URL and nothing else.
            reply_to_for_attachment: MessageReference | None = MessageReference(
                reply_to.legacy_id, reply_to.author
            )
        else:
            reply_to_for_attachment = reply_to

        msg = self._make_message(
            when=when,
            reply_to=reply_to_for_attachment,
            carbon=carbon,
            mto=mto,
            thread=thread,
        )

        if content_type is None and (name := (file_name or file_path or file_url)):
            content_type, _ = guess_type(name)

        is_temp, local_path, new_url = await self.__get_url(
            Path(file_path) if file_path else None,
            async_data_stream,
            data_stream,
            data,
            file_url,
            file_name,
            content_type,
            legacy_file_id,
        )

        try:
            if new_url is None:
                msg["body"] = (
                    "I tried to send a file, but something went wrong. "
                    "Tell your slidge admin to check the logs."
                )
                self._set_msg_id(msg, legacy_msg_id)
                # msg was built with carbon=carbon, send it the same way
                return None, [self._send(msg, carbon=carbon, **kwargs)]

            thumbnail = await self.__set_sims(
                msg, new_url, local_path, content_type, caption, file_name
            )
            self.__set_sfs(
                msg, new_url, local_path, content_type, caption, file_name, thumbnail
            )
        finally:
            # The temporary file is only needed to compute SIMS/SFS metadata;
            # remove it even if the upload failed or metadata generation raised.
            if is_temp and isinstance(local_path, Path):
                local_path.unlink(missing_ok=True)
                try:
                    local_path.parent.rmdir()
                except OSError as e:
                    log.debug("Could not remove temporary dir", exc_info=e)

        msgs = self.__send_url(
            msg, legacy_msg_id, new_url, caption, carbon, when, **kwargs
        )
        if store_multi:
            self.__store_multi(legacy_msg_id, msgs)
        return new_url, msgs

    def __send_body(
        self,
        body: Optional[str] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> Optional[Message]:
        """
        Send ``body`` as a text message, or do nothing if it is empty.

        :return: The sent message, or ``None`` if there was no body.
        """
        if body:
            return self.send_text(
                body,
                legacy_msg_id,
                reply_to=reply_to,
                when=when,
                thread=thread,
                **kwargs,
            )
        else:
            return None

    async def send_files(
        self,
        attachments: Collection[LegacyAttachment],
        legacy_msg_id: Optional[LegacyMessageType] = None,
        body: Optional[str] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        body_first=False,
        correction=False,
        correction_event_id: Optional[LegacyMessageType] = None,
        **kwargs,
    ):
        """
        Send several files, and optionally a text body, as separate messages.

        :param attachments: The files to send. Each one yields one message,
            or two if it has a caption.
        :param legacy_msg_id: ID of the legacy message. It is attached to the
            body message if there is a body, otherwise to the last attachment,
            so that no two sent messages share it. The IDs of all sent
            messages are recorded against it.
        :param body: Optional text sent along with the attachments.
        :param body_first: Send the body before the attachments instead of
            after them.
        :param correction: Whether the body corrects a previous message (only
            forwarded to the body message).
        :param correction_event_id: Only forwarded to the body message.
        """
        # TODO: once the epic XEP-0385 vs XEP-0447 battle is over, pick
        # one and stop sending several attachments this way
        # we attach the legacy_message ID to the last message we send, because
        # we don't want several messages with the same ID (especially for MUC MAM)
        # TODO: refactor this so we limit the number of SQL calls, ie, if
        # the legacy file ID is known, only fetch the row once, and if it
        # is new, write it all in a single call
        if not attachments and not body:
            # ignoring empty message
            return
        send_body = functools.partial(
            self.__send_body,
            body=body,
            reply_to=reply_to,
            when=when,
            thread=thread,
            correction=correction,
            legacy_msg_id=legacy_msg_id,
            correction_event_id=correction_event_id,
            **kwargs,
        )
        all_msgs = []
        if body_first:
            all_msgs.append(send_body())
        last_attachment_i = len(attachments) - 1
        for i, attachment in enumerate(attachments):
            last = i == last_attachment_i
            if last and not body:
                legacy = legacy_msg_id
            else:
                legacy = None
            _url, msgs = await self.send_file(
                file_path=attachment.path,
                legacy_msg_id=legacy,
                file_url=attachment.url,
                data_stream=attachment.stream,
                data=attachment.data,
                reply_to=reply_to,
                when=when,
                thread=thread,
                file_name=attachment.name,
                content_type=attachment.content_type,
                legacy_file_id=attachment.legacy_file_id,
                caption=attachment.caption,
                store_multi=False,
                **kwargs,
            )
            all_msgs.extend(msgs)
        if not body_first:
            all_msgs.append(send_body())
        self.__store_multi(legacy_msg_id, all_msgs)

    def __store_multi(
        self,
        legacy_msg_id: Optional[LegacyMessageType],
        all_msgs: Sequence[Optional[Message]],
    ):
        """
        Record the XMPP IDs of ``all_msgs`` as the messages that make up
        ``legacy_msg_id``.

        The stanza-id is used when present, the message ID otherwise. ``None``
        entries are skipped, and nothing is stored without a legacy ID.
        """
        if legacy_msg_id is None:
            return
        ids = []
        for msg in all_msgs:
            if not msg:
                continue
            if stanza_id := msg.get_plugin("stanza_id", check=True):
                ids.append(stanza_id["id"])
            else:
                ids.append(msg.get_id())
        self.xmpp.store.multi.set_xmpp_ids(
            self.session.user_pk, str(legacy_msg_id), ids
        )

540 

541 

def get_thumbhash(path: Path) -> tuple[str, int, int]:
    """
    Compute the thumbhash of an image file.

    Before hashing, the image is:

    - converted to RGBA;
    - downscaled to fit in 100x100 pixels;
    - rotated according to its EXIF orientation.

    :param path: Path to an image file readable by Pillow.
    :return: ``(thumbhash, width, height)``.
        - ``thumbhash`` is base64-encoded.
        - ``width``/``height`` are the full-size dimensions of the image as
          displayed, i.e. swapped when its EXIF orientation involves a
          quarter turn, so they match the orientation of the hash.
    """
    with path.open("rb") as fp, Image.open(fp) as src:
        width, height = src.size
        # EXIF orientations 5 to 8 include a 90/270 degree rotation, so the
        # displayed image is height x width. 0x0112 is the "Orientation" tag.
        if src.getexif().get(0x0112) in (5, 6, 7, 8):
            width, height = height, width
        img = src.convert("RGBA")
    if img.width > 100 or img.height > 100:
        img.thumbnail((100, 100))
    img = ImageOps.exif_transpose(img)
    rgba = list(chain.from_iterable(img.getdata()))
    ints = thumbhash.rgba_to_thumb_hash(img.width, img.height, rgba)
    return base64.b64encode(bytes(ints)).decode(), width, height

554 

555 

# Module-level logger, used by the attachment helpers in this module.
log = logging.getLogger(name=__name__)