Coverage for slidge/core/mixins/attachment.py: 83%

297 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1import base64 

2import functools 

3import logging 

4import os 

5import re 

6import shutil 

7import stat 

8import tempfile 

9import warnings 

10from datetime import datetime 

11from itertools import chain 

12from mimetypes import guess_extension, guess_type 

13from pathlib import Path 

14from typing import Collection, Optional, Sequence, Union 

15from urllib.parse import quote as urlquote 

16from uuid import uuid4 

17from xml.etree import ElementTree as ET 

18 

19import thumbhash 

20from PIL import Image, ImageOps 

21from slixmpp import JID, Message 

22from slixmpp.exceptions import IqError, IqTimeout 

23from slixmpp.plugins.xep_0264.stanza import Thumbnail 

24from slixmpp.plugins.xep_0363 import FileUploadError 

25from slixmpp.plugins.xep_0447.stanza import StatelessFileSharing 

26 

27from ...db.avatar import avatar_cache 

28from ...db.models import Attachment 

29from ...util.types import ( 

30 LegacyAttachment, 

31 LegacyMessageType, 

32 LegacyThreadType, 

33 MessageReference, 

34) 

35from ...util.util import fix_suffix 

36from .. import config 

37from .message_text import TextMessageMixin 

38 

39 

class AttachmentMixin(TextMessageMixin):
    """
    Mixin that lets an :term:`XMPP Entity` send files.

    Files are either uploaded with HTTP File Upload (:xep:`0363`) or, when
    ``config.NO_UPLOAD_PATH`` is set, published in a local directory whose
    content is reachable under ``config.NO_UPLOAD_URL_PREFIX``. The resulting
    message carries the URL as body and out-of-band URL, plus SIMS
    (:xep:`0385`) and SFS (:xep:`0447`) metadata. Attachments that have a
    ``legacy_file_id`` are looked up in (and saved to) the
    :class:`Attachment` table so that their URL and metadata can be reused.
    """

    # Whether this entity is a group chat; forwarded to the legacy/XMPP ID map.
    is_group: bool

    async def __upload(
        self,
        file_path: Path,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
    ) -> str | None:
        """
        Upload a local file with HTTP File Upload (:xep:`0363`).

        :param file_path: Local file to upload.
        :param file_name: Name to announce to the upload service when it
            differs from ``file_path.name``. A symlink with that name is
            created in a temporary directory for the duration of the upload.
        :param content_type: MIME type to announce, if known.
        :return: The URL of the uploaded file, or ``None`` if the upload
            service returned an error or timed out (a warning is emitted).
        """
        tmp_dir: Optional[Path] = None
        if file_name and file_path.name != file_name:
            tmp_dir = Path(tempfile.mkdtemp())
            link = tmp_dir / file_name
            # An absolute target is required: a relative one would be
            # resolved against tmp_dir and the symlink would dangle.
            link.symlink_to(file_path.absolute())
            file_path = link

        if config.UPLOAD_SERVICE:
            domain = None
        else:
            # Drop the first label of our own bare JID,
            # e.g. "gateway.example.org" -> "example.org".
            domain = re.sub(r"^.*?\.", "", self.xmpp.boundjid.bare)

        try:
            new_url = await self.xmpp.plugin["xep_0363"].upload_file(
                filename=file_path,
                content_type=content_type,
                ifrom=config.UPLOAD_REQUESTER or self.xmpp.boundjid,
                domain=JID(domain),
            )
        except (FileUploadError, IqError, IqTimeout) as e:
            warnings.warn(f"Something is wrong with the upload service: {e!r}")
            return None
        finally:
            # Always remove the renaming symlink and its directory.
            if tmp_dir is not None:
                file_path.unlink()
                tmp_dir.rmdir()

        return new_url

    @staticmethod
    async def __no_upload(
        file_path: Path,
        file_name: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ):
        """
        Publish a file under ``config.NO_UPLOAD_PATH`` instead of uploading it.

        The file ends up at ``NO_UPLOAD_PATH/<file_id>/<uuid>/<name>``, where
        ``file_id`` is ``legacy_file_id`` (or a random UUID) and ``<uuid>`` is
        random so that URLs are not trivially guessable. If the directory for
        ``file_id`` already contains exactly one file, that file is reused.
        Otherwise, an existing directory is deleted and recreated.

        :param file_path: Local file to publish.
        :param file_name: Name of the published file; defaults to
            ``file_path.name``.
        :param legacy_file_id: Stable identifier used to find a previously
            published copy of the same file.
        :return: ``(published_path, relative_url)``, where ``relative_url`` is
            ``"<file_id>/<uuid>/<name>"`` (not URL-quoted).
        :raises RuntimeError: if ``config.NO_UPLOAD_METHOD`` is not one of
            ``copy``, ``hardlink``, ``symlink`` or ``move``.
        """
        file_id = str(uuid4()) if legacy_file_id is None else str(legacy_file_id)
        assert config.NO_UPLOAD_PATH is not None
        assert config.NO_UPLOAD_URL_PREFIX is not None

        id_path = Path(file_id)
        # legacy_file_id comes from the legacy network. An absolute path, an
        # empty or "." id, or a ".." component would make destination_dir
        # point outside NO_UPLOAD_PATH (or at its root). The rmtree() below
        # could then delete unrelated data.
        if id_path.anchor or not id_path.parts or ".." in id_path.parts:
            log.warning(
                "Legacy file ID %r cannot be used as a directory name, "
                "using a random one instead",
                legacy_file_id,
            )
            file_id = str(uuid4())
        destination_dir = Path(config.NO_UPLOAD_PATH) / file_id

        if destination_dir.exists():
            log.debug("Dest dir exists: %s", destination_dir)
            files = [f for f in destination_dir.glob("**/*") if f.is_file()]
            if len(files) == 1:
                log.debug(
                    "Found the legacy attachment '%s' at '%s'",
                    legacy_file_id,
                    files[0],
                )
                name = files[0].name
                uu = files[0].parent.name  # anti-obvious url trick, see below
                return files[0], "/".join([file_id, uu, name])
            log.warning(
                (
                    "There are several or zero files in %s, "
                    "slidge doesn't know which one to pick among %s. "
                    "Removing the dir."
                ),
                destination_dir,
                files,
            )
            shutil.rmtree(destination_dir)

        log.debug("Did not find a file in: %s", destination_dir)
        # let's use a UUID to avoid URLs being too obvious
        uu = str(uuid4())
        destination_dir = destination_dir / uu
        destination_dir.mkdir(parents=True)

        name = file_name or file_path.name
        destination = destination_dir / name
        method = config.NO_UPLOAD_METHOD
        if method == "copy":
            shutil.copy2(file_path, destination)
        elif method == "hardlink":
            os.link(file_path, destination)
        elif method == "symlink":
            # The target is a regular file, not a directory (that flag only
            # matters on Windows). An absolute target keeps the link valid
            # even when file_path is relative.
            os.symlink(file_path.absolute(), destination)
        elif method == "move":
            shutil.move(file_path, destination)
        else:
            raise RuntimeError("No upload method not recognized", method)

        if config.NO_UPLOAD_FILE_READ_OTHERS:
            log.debug("Changing perms of %s", destination)
            destination.chmod(destination.stat().st_mode | stat.S_IROTH)
        uploaded_url = "/".join([file_id, uu, name])

        return destination, uploaded_url

    async def __valid_url(self, url: str) -> bool:
        """Return True if a HEAD request on ``url`` answers with a status < 400."""
        async with self.session.http.head(url) as r:
            return r.status < 400

    async def __get_stored(self, attachment: LegacyAttachment) -> Attachment:
        """
        Return the cached :class:`Attachment` row for ``attachment``, or a new one.

        A cached row is looked up by ``legacy_file_id``. If its URL no longer
        answers a HEAD request, the URL and the cached SIMS/SFS payloads
        (which embed that URL) are cleared so that they get regenerated.
        """
        if attachment.legacy_file_id is not None:
            with self.xmpp.store.session() as orm:
                stored = (
                    orm.query(Attachment)
                    .filter_by(legacy_file_id=str(attachment.legacy_file_id))
                    .one_or_none()
                )
                if stored is not None:
                    if not await self.__valid_url(stored.url):
                        stored.url = None  # type:ignore
                        # Cached metadata references the dead URL; drop it
                        # so that __set_sims/__set_sfs rebuild it.
                        stored.sims = None  # type:ignore
                        stored.sfs = None  # type:ignore
                    return stored
        return Attachment(
            user_account_id=None
            if self.session is NotImplemented
            else self.session.user_pk,
            legacy_file_id=None
            if attachment.legacy_file_id is None
            else str(attachment.legacy_file_id),
            url=attachment.url,
        )

    async def __get_url(
        self, attachment: LegacyAttachment, stored: Attachment
    ) -> tuple[bool, Optional[Path], Optional[str]]:
        """
        Obtain a URL for an attachment that has no usable cached URL.

        If ``config.USE_ATTACHMENT_ORIGINAL_URLS`` is set and the attachment
        has a URL, that URL is used as-is. Otherwise, the attachment content
        (local path, URL download, sync/async stream or raw data) is written
        to disk if needed. It is then either published locally (no-upload
        mode) or uploaded with :xep:`0363`.

        :return: ``(is_temp, local_path, url)``. ``is_temp`` is True when
            ``local_path`` was written to a temporary directory created here
            (only outside no-upload mode). ``local_path`` is ``None`` when the
            original URL is reused. ``url`` is ``None`` if the upload failed.
        """
        if attachment.url and config.USE_ATTACHMENT_ORIGINAL_URLS:
            return False, None, attachment.url

        file_name = attachment.name
        content_type = attachment.content_type
        file_path = attachment.path

        if file_name is not None:
            # Attachment names come from the legacy network and are used as
            # file names below. Keep only the last path component so that
            # e.g. "../../x" cannot escape the target directory.
            base_name = Path(file_name).name
            file_name = base_name if base_name not in ("", "..") else None

        if file_name and len(file_name) > config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH:
            log.debug("Trimming long filename: %s", file_name)
            base, ext = os.path.splitext(file_name)
            file_name = (
                base[: config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH - len(ext)] + ext
            )

        if file_path is None:
            if file_name is None:
                file_name = str(uuid4())
                if content_type is not None:
                    ext = guess_extension(content_type, strict=False)  # type:ignore
                    if ext is not None:
                        file_name += ext
            temp_dir = Path(tempfile.mkdtemp())
            file_path = temp_dir / file_name
            if attachment.url:
                async with self.session.http.get(attachment.url) as r:
                    r.raise_for_status()
                    with file_path.open("wb") as f:
                        f.write(await r.read())

            elif attachment.stream is not None:
                data = attachment.stream.read()
                if data is None:
                    raise RuntimeError

                with file_path.open("wb") as f:
                    f.write(data)
            elif attachment.aio_stream is not None:
                # TODO: patch slixmpp to allow this as data source for
                # upload_file() so we don't even have to write anything
                # to disk.
                with file_path.open("wb") as f:
                    async for chunk in attachment.aio_stream:
                        f.write(chunk)
            elif attachment.data is not None:
                with file_path.open("wb") as f:
                    f.write(attachment.data)

            # In no-upload mode the file is handed over to __no_upload
            # (which may symlink or move it), so it must not be deleted.
            is_temp = not bool(config.NO_UPLOAD_PATH)
        else:
            is_temp = False

        assert isinstance(file_path, Path)
        if config.FIX_FILENAME_SUFFIX_MIME_TYPE:
            file_name = str(fix_suffix(file_path, content_type, file_name))

        if config.NO_UPLOAD_PATH:
            local_path, new_url = await self.__no_upload(
                file_path, file_name, stored.legacy_file_id
            )
            new_url = (config.NO_UPLOAD_URL_PREFIX or "") + "/" + urlquote(new_url)
        else:
            local_path = file_path
            new_url = await self.__upload(file_path, file_name, content_type)
            if stored.legacy_file_id and new_url is not None:
                stored.url = new_url

        return is_temp, local_path, new_url

    @staticmethod
    def __remove_temp(is_temp: bool, local_path: Optional[Path]) -> None:
        """Delete a temporary file written by __get_url and its directory."""
        if is_temp and isinstance(local_path, Path):
            local_path.unlink()
            local_path.parent.rmdir()

    async def __set_sims(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        attachment: LegacyAttachment,
        stored: Attachment,
    ) -> Thumbnail | None:
        """
        Append a SIMS reference (:xep:`0385`) to ``msg``.

        The cached ``stored.sims`` XML is reused when present. Otherwise the
        reference is built from the local file at ``path`` (nothing is
        appended if there is none) and cached in ``stored.sims``. For image
        content types, a thumbhash thumbnail is generated in a worker thread
        on a best-effort basis.

        :return: The thumbnail element, if any, so SFS can reuse it.
        """
        if stored.sims is not None:
            ref = self.xmpp["xep_0372"].stanza.Reference(xml=ET.fromstring(stored.sims))
            msg.append(ref)
            if ref["sims"]["file"].get_plugin("thumbnail", check=True):
                return ref["sims"]["file"]["thumbnail"]
            else:
                return None

        if not path:
            return None

        ref = self.xmpp["xep_0385"].get_sims(
            path, [uploaded_url], attachment.content_type, attachment.caption
        )
        if attachment.name:
            ref["sims"]["file"]["name"] = attachment.name
        thumbnail = None
        if attachment.content_type is not None and attachment.content_type.startswith(
            "image"
        ):
            try:
                h, x, y = await self.xmpp.loop.run_in_executor(
                    avatar_cache._thread_pool, get_thumbhash, path
                )
            except Exception as e:
                # Thumbnails are optional: never fail the whole send for them.
                log.debug("Could not generate a thumbhash", exc_info=e)
            else:
                thumbnail = ref["sims"]["file"]["thumbnail"]
                thumbnail["width"] = x
                thumbnail["height"] = y
                thumbnail["media-type"] = "image/thumbhash"
                thumbnail["uri"] = "data:image/thumbhash;base64," + urlquote(h)

        stored.sims = str(ref)
        msg.append(ref)

        return thumbnail

    def __set_sfs(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        attachment: LegacyAttachment,
        stored: Attachment,
        thumbnail: Optional[Thumbnail] = None,
    ) -> None:
        """
        Append a Stateless File Sharing element (:xep:`0447`) to ``msg``.

        The cached ``stored.sfs`` XML is reused when present. Otherwise the
        element is built from the local file at ``path`` (nothing is appended
        if there is none), optionally with ``thumbnail``, and cached in
        ``stored.sfs``.
        """
        if stored.sfs is not None:
            msg.append(StatelessFileSharing(xml=ET.fromstring(stored.sfs)))
            return

        if not path:
            return

        sfs = self.xmpp["xep_0447"].get_sfs(
            path, [uploaded_url], attachment.content_type, attachment.caption
        )
        if attachment.name:
            sfs["file"]["name"] = attachment.name
        if thumbnail is not None:
            sfs["file"].append(thumbnail)
        stored.sfs = str(sfs)
        msg.append(sfs)

    def __send_url(
        self,
        msg: Message,
        legacy_msg_id: LegacyMessageType,
        uploaded_url: str,
        caption: Optional[str] = None,
        carbon: bool = False,
        when: Optional[datetime] = None,
        correction: bool = False,
        **kwargs,
    ) -> list[Message]:
        """
        Send ``msg`` with ``uploaded_url`` as both OOB URL and body.

        With a caption, the file message is sent without a legacy ID and the
        caption follows as a text message carrying ``legacy_msg_id``. Without
        one, the file message itself carries ``legacy_msg_id``, or a
        replace ID when ``correction`` is set.

        :return: The sent stanzas.
        """
        msg["oob"]["url"] = uploaded_url
        msg["body"] = uploaded_url
        if caption:
            m1 = self._send(msg, carbon=carbon, **kwargs)
            m2 = self.send_text(
                caption,
                legacy_msg_id=legacy_msg_id,
                when=when,
                carbon=carbon,
                correction=correction,
                **kwargs,
            )
            return [m1, m2] if m2 else [m1]
        else:
            if correction:
                msg["replace"]["id"] = self._replace_id(legacy_msg_id)
            else:
                self._set_msg_id(msg, legacy_msg_id)
            return [self._send(msg, carbon=carbon, **kwargs)]

    def __get_base_message(
        self,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        carbon: bool = False,
        correction: bool = False,
        mto: Optional[JID] = None,
    ) -> Message:
        """
        Build the message stanza that will carry an attachment.

        For a correction of a known message, every XMPP message previously
        mapped to ``legacy_msg_id``, except the one being corrected, is
        retracted first. A reply's quoted body is dropped (see below).
        """
        if correction and (original_xmpp_id := self._legacy_to_xmpp(legacy_msg_id)):
            with self.xmpp.store.session() as orm:
                xmpp_ids = self.xmpp.store.id_map.get_xmpp(
                    orm, self._recipient_pk(), str(legacy_msg_id), self.is_group
                )

            for xmpp_id in xmpp_ids:
                if xmpp_id == original_xmpp_id:
                    continue
                self.retract(xmpp_id, thread)

        if reply_to is not None and reply_to.body:
            # We cannot have a "quote fallback" for attachments since most (all?)
            # XMPP clients will only treat a message as an attachment if the
            # body is the URL and nothing else.
            reply_to_for_attachment: MessageReference | None = MessageReference(
                reply_to.legacy_id, reply_to.author
            )
        else:
            reply_to_for_attachment = reply_to

        return self._make_message(
            when=when,
            reply_to=reply_to_for_attachment,
            carbon=carbon,
            mto=mto,
            thread=thread,
        )

    async def send_file(
        self,
        attachment: LegacyAttachment | Path | str,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> tuple[Optional[str], list[Message]]:
        """
        Send a single file from this :term:`XMPP Entity`.

        :param attachment: The file to send.
            Ideally, a :class:`.LegacyAttachment` with a unique ``legacy_file_id``
            attribute set, to optimise potential future reuses.
            It can also be:

            - a :class:`pathlib.Path` instance to point to a local file, or
            - a ``str``, representing a fetchable HTTP URL.
        :param legacy_msg_id: If you want to be able to transport read markers from the gateway
            user to the legacy network, specify this
        :param reply_to: Quote another message (:xep:`0461`)
        :param when: when the file was sent, for a "delay" tag (:xep:`0203`)
        :param thread: The legacy thread the message belongs to, if any.
        :return: The file URL (``None`` if it could not be uploaded) and the
            list of stanzas that were sent.
        """
        store_multi = kwargs.pop("store_multi", True)
        carbon = kwargs.pop("carbon", False)
        mto = kwargs.pop("mto", None)
        correction = kwargs.get("correction", False)

        msg = self.__get_base_message(
            legacy_msg_id, reply_to, when, thread, carbon, correction, mto
        )

        if isinstance(attachment, str):
            attachment = LegacyAttachment(url=attachment)
        elif isinstance(attachment, Path):
            attachment = LegacyAttachment(path=attachment)

        stored = await self.__get_stored(attachment)

        if attachment.content_type is None and (
            name := (attachment.name or attachment.url or attachment.path)
        ):
            attachment.content_type, _ = guess_type(name)

        if stored.url:
            is_temp = False
            local_path = None
            new_url = stored.url
        else:
            is_temp, local_path, new_url = await self.__get_url(attachment, stored)
            if new_url is None:
                # Do not leak the temporary download when the upload failed.
                self.__remove_temp(is_temp, local_path)
                msg["body"] = (
                    "I tried to send a file, but something went wrong. "
                    "Tell your slidge admin to check the logs."
                )
                self._set_msg_id(msg, legacy_msg_id)
                # msg was built with this carbon flag: send it the same way.
                return None, [self._send(msg, carbon=carbon, **kwargs)]

        stored.url = new_url
        thumbnail = await self.__set_sims(msg, new_url, local_path, attachment, stored)
        self.__set_sfs(msg, new_url, local_path, attachment, stored, thumbnail)

        if self.session is not NotImplemented:
            with self.xmpp.store.session(expire_on_commit=False) as orm:
                orm.add(stored)
                orm.commit()

        self.__remove_temp(is_temp, local_path)

        msgs = self.__send_url(
            msg, legacy_msg_id, new_url, attachment.caption, carbon, when, **kwargs
        )
        if self.session is not NotImplemented:
            if store_multi:
                self.__store_multi(legacy_msg_id, msgs)
        return new_url, msgs

    def __send_body(
        self,
        body: Optional[str] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> Optional[Message]:
        """Send ``body`` as a text message; do nothing and return None if empty."""
        if body:
            return self.send_text(
                body,
                legacy_msg_id,
                reply_to=reply_to,
                when=when,
                thread=thread,
                **kwargs,
            )
        else:
            return None

    async def send_files(
        self,
        attachments: Collection[LegacyAttachment],
        legacy_msg_id: Optional[LegacyMessageType] = None,
        body: Optional[str] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        body_first: bool = False,
        correction: bool = False,
        correction_event_id: Optional[LegacyMessageType] = None,
        **kwargs,
    ) -> None:
        """
        Send several files and an optional text body as separate messages.

        Only one outgoing message carries ``legacy_msg_id``: the body if there
        is one, otherwise the last attachment. All sent stanzas are then
        mapped to ``legacy_msg_id``.

        :param attachments: The files to send, in order.
        :param legacy_msg_id: Legacy ID of the whole multi-part message.
        :param body: Optional text sent alongside the files.
        :param body_first: Send the body before the files instead of after.
        :param correction: Forwarded to the body message.
        :param correction_event_id: Forwarded to the body message.
        """
        # TODO: once the epic XEP-0385 vs XEP-0447 battle is over, pick
        # one and stop sending several attachments this way
        # we attach the legacy_message ID to a single message we send, because
        # we don't want several messages with the same ID (especially for MUC MAM)
        if not attachments and not body:
            # ignoring empty message
            return
        send_body = functools.partial(
            self.__send_body,
            body=body,
            reply_to=reply_to,
            when=when,
            thread=thread,
            correction=correction,
            legacy_msg_id=legacy_msg_id,
            correction_event_id=correction_event_id,
            **kwargs,
        )
        all_msgs = []
        if body_first:
            all_msgs.append(send_body())
        last_attachment_i = len(attachments) - 1
        for i, attachment in enumerate(attachments):
            last = i == last_attachment_i
            if last and not body:
                legacy = legacy_msg_id
            else:
                legacy = None
            _url, msgs = await self.send_file(
                attachment,
                legacy,
                reply_to=reply_to,
                when=when,
                thread=thread,
                store_multi=False,
                **kwargs,
            )
            all_msgs.extend(msgs)
        if not body_first:
            all_msgs.append(send_body())
        self.__store_multi(legacy_msg_id, all_msgs)

    def __store_multi(
        self,
        legacy_msg_id: Optional[LegacyMessageType],
        all_msgs: Sequence[Optional[Message]],
    ) -> None:
        """
        Map every sent stanza in ``all_msgs`` to ``legacy_msg_id``.

        The stanza-id is used when a message has one, otherwise its ID.
        ``None`` entries are skipped; nothing is stored without a legacy ID.
        """
        if legacy_msg_id is None:
            return
        ids = []
        for msg in all_msgs:
            if not msg:
                continue
            if stanza_id := msg.get_plugin("stanza_id", check=True):
                ids.append(stanza_id["id"])
            else:
                ids.append(msg.get_id())
        with self.xmpp.store.session() as orm:
            self.xmpp.store.id_map.set_msg(
                orm, self._recipient_pk(), str(legacy_msg_id), ids, self.is_group
            )
            orm.commit()

556 

557 

def get_thumbhash(path: Path) -> tuple[str, int, int]:
    """
    Compute a ThumbHash placeholder for the image stored at ``path``.

    The EXIF orientation is applied *before* the image is measured, so the
    returned dimensions describe the picture as displayed. For example, a
    portrait photo stored sideways reports portrait dimensions.

    :param path: Image file readable by Pillow.
    :return: ``(thumbhash, width, height)``. ``thumbhash`` is base64-encoded;
        ``width`` and ``height`` are the full-size dimensions after EXIF
        transposition.
    """
    with path.open("rb") as fp:
        img = Image.open(fp)
        # Transpose first: measuring the raw image would swap width and
        # height for 90/270-degree EXIF orientations.
        img = ImageOps.exif_transpose(img)
        width, height = img.size
        img = img.convert("RGBA")
        # ThumbHash expects a small input (at most 100x100 pixels).
        if width > 100 or height > 100:
            img.thumbnail((100, 100))
    rgba = list(chain.from_iterable(img.getdata()))
    ints = thumbhash.rgba_to_thumb_hash(img.width, img.height, rgba)
    return base64.b64encode(bytes(ints)).decode(), width, height

570 

571 

# Module-wide logger; the functions above look it up at call time.
log = logging.getLogger(name=__name__)