Coverage for slidge/core/mixins/attachment.py: 83%
297 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1import base64
2import functools
3import logging
4import os
5import re
6import shutil
7import stat
8import tempfile
9import warnings
10from datetime import datetime
11from itertools import chain
12from mimetypes import guess_extension, guess_type
13from pathlib import Path
14from typing import Collection, Optional, Sequence, Union
15from urllib.parse import quote as urlquote
16from uuid import uuid4
17from xml.etree import ElementTree as ET
19import thumbhash
20from PIL import Image, ImageOps
21from slixmpp import JID, Message
22from slixmpp.exceptions import IqError, IqTimeout
23from slixmpp.plugins.xep_0264.stanza import Thumbnail
24from slixmpp.plugins.xep_0363 import FileUploadError
25from slixmpp.plugins.xep_0447.stanza import StatelessFileSharing
27from ...db.avatar import avatar_cache
28from ...db.models import Attachment
29from ...util.types import (
30 LegacyAttachment,
31 LegacyMessageType,
32 LegacyThreadType,
33 MessageReference,
34)
35from ...util.util import fix_suffix
36from .. import config
37from .message_text import TextMessageMixin
class AttachmentMixin(TextMessageMixin):
    """
    Mixin that lets an XMPP entity send files (attachments).

    A file is made available either through :xep:`0363` HTTP File Upload or,
    when ``config.NO_UPLOAD_PATH`` is set, by placing it under that directory
    and building its URL from ``config.NO_UPLOAD_URL_PREFIX``. The resulting
    URL and the generated SIMS (:xep:`0385`) / SFS (:xep:`0447`) elements are
    persisted in :class:`Attachment` rows, so that an attachment carrying a
    ``legacy_file_id`` can be re-sent without being fetched and uploaded again.
    """

    # Whether this entity is a group chat; forwarded to the message-ID store.
    is_group: bool

    @staticmethod
    def __safe_file_name(file_name: Optional[str]) -> Optional[str]:
        """
        Strip every directory component from a file name.

        File names come from the legacy network and are joined onto local
        directories, so values such as ``"../x"`` or ``"/etc/x"`` must not be
        able to escape them.

        :param file_name: Candidate file name, possibly ``None``.
        :return: The last path component, or ``None`` if nothing usable is
            left (empty string, ``"."`` or ``".."``).
        """
        if file_name is None:
            return None
        base = os.path.basename(file_name)
        if base in ("", ".", ".."):
            return None
        return base

    @staticmethod
    def __remove_temp(path: Path) -> None:
        """
        Delete a file written by ``__write_temp_file`` together with the
        private temporary directory (from :func:`tempfile.mkdtemp`) holding it.

        Best-effort: errors (e.g. the file is already gone) are ignored.
        """
        shutil.rmtree(path.parent, ignore_errors=True)

    async def __upload(
        self,
        file_path: Path,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
    ) -> str | None:
        """
        Upload a local file with :xep:`0363` HTTP File Upload.

        :param file_path: Local file to upload.
        :param file_name: Name to upload the file under, if it differs from
            ``file_path.name``. In that case a symlink with this name is
            created in a temporary directory and uploaded instead; presumably
            the upload plugin takes the name from the path it is given.
        :param content_type: MIME type passed to the upload request.
        :return: The URL of the uploaded file, or ``None`` if the upload
            failed with :class:`FileUploadError`, :class:`IqError` or
            :class:`IqTimeout` (a warning is emitted in that case).
        """
        if file_name and file_path.name != file_name:
            d = Path(tempfile.mkdtemp())
            temp = d / file_name
            # Absolute target: a relative one would be resolved against `d`.
            temp.symlink_to(file_path.absolute())
            file_path = temp
        else:
            d = None
        if config.UPLOAD_SERVICE:
            domain = None
        else:
            # Drop the first label of our JID, e.g. "gw.example.org" -> "example.org".
            domain = re.sub(r"^.*?\.", "", self.xmpp.boundjid.bare)
        try:
            new_url = await self.xmpp.plugin["xep_0363"].upload_file(
                filename=file_path,
                content_type=content_type,
                ifrom=config.UPLOAD_REQUESTER or self.xmpp.boundjid,
                # NOTE(review): JID(None) is an empty JID; presumably the plugin
                # then relies on its configured upload service — confirm.
                domain=JID(domain),
            )
        except (FileUploadError, IqError, IqTimeout) as e:
            warnings.warn(f"Something is wrong with the upload service: {e!r}")
            return None
        finally:
            if d is not None:
                # Remove the renaming symlink and its temporary directory.
                file_path.unlink()
                d.rmdir()

        return new_url

    @staticmethod
    async def __no_upload(
        file_path: Path,
        file_name: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ) -> tuple[Path, str]:
        """
        Publish a file under ``config.NO_UPLOAD_PATH`` instead of uploading it.

        Files are stored as ``<NO_UPLOAD_PATH>/<file_id>/<random uuid>/<name>``
        where ``file_id`` is ``legacy_file_id`` (or a random UUID when it is
        ``None`` or would point outside ``NO_UPLOAD_PATH``). If the
        ``<file_id>`` directory already holds exactly one file, that file is
        reused; otherwise the directory is deleted and the file is placed with
        ``config.NO_UPLOAD_METHOD`` (``copy``, ``hardlink``, ``symlink`` or
        ``move``).

        :param file_path: Local file to publish.
        :param file_name: Name of the published file; defaults to
            ``file_path.name``.
        :param legacy_file_id: Stable identifier used to find a previously
            published copy of the same file.
        :raises RuntimeError: if ``config.NO_UPLOAD_METHOD`` is unknown.
        :return: The published file's path and its URL path, relative to
            ``NO_UPLOAD_URL_PREFIX`` and not yet URL-quoted.
        """
        assert config.NO_UPLOAD_PATH is not None
        assert config.NO_UPLOAD_URL_PREFIX is not None
        root = Path(config.NO_UPLOAD_PATH)
        file_id = str(uuid4()) if legacy_file_id is None else str(legacy_file_id)
        destination_dir = root / file_id

        # legacy_file_id comes from the legacy network: it must never resolve
        # to the root itself or outside of it (e.g. "", "..", "/abs/path"),
        # because that directory may be rmtree'd below.
        root_resolved = root.resolve()
        dest_resolved = destination_dir.resolve()
        if dest_resolved == root_resolved or not dest_resolved.is_relative_to(
            root_resolved
        ):
            log.warning(
                "Unsafe legacy file ID, using a random one instead: %r",
                legacy_file_id,
            )
            file_id = str(uuid4())
            destination_dir = root / file_id

        if destination_dir.exists():
            log.debug("Dest dir exists: %s", destination_dir)
            files = [f for f in destination_dir.glob("**/*") if f.is_file()]
            if len(files) == 1:
                log.debug(
                    "Found the legacy attachment '%s' at '%s'",
                    legacy_file_id,
                    files[0],
                )
                name = files[0].name
                uu = files[0].parent.name  # anti-obvious url trick, see below
                return files[0], "/".join([file_id, uu, name])
            log.warning(
                (
                    "There are several or zero files in %s, "
                    "slidge doesn't know which one to pick among %s. "
                    "Removing the dir."
                ),
                destination_dir,
                files,
            )
            shutil.rmtree(destination_dir)

        log.debug("Did not find a file in: %s", destination_dir)
        # let's use a UUID to avoid URLs being too obvious
        uu = str(uuid4())
        destination_dir = destination_dir / uu
        destination_dir.mkdir(parents=True)

        name = file_name or file_path.name
        destination = destination_dir / name
        method = config.NO_UPLOAD_METHOD
        if method == "copy":
            shutil.copy2(file_path, destination)
        elif method == "hardlink":
            os.link(file_path, destination)
        elif method == "symlink":
            # The target is a file (target_is_directory only matters on
            # Windows, for directory targets) and must be absolute, otherwise
            # it would be resolved relative to destination_dir.
            os.symlink(file_path.absolute(), destination)
        elif method == "move":
            shutil.move(file_path, destination)
        else:
            raise RuntimeError("No upload method not recognized", method)

        if config.NO_UPLOAD_FILE_READ_OTHERS:
            log.debug("Changing perms of %s", destination)
            destination.chmod(destination.stat().st_mode | stat.S_IROTH)
        uploaded_url = "/".join([file_id, uu, name])

        return destination, uploaded_url

    async def __valid_url(self, url: Optional[str]) -> bool:
        """
        Tell whether a previously stored URL still answers a HEAD request
        with a status below 400.

        A missing URL or any error while checking (network failure, invalid
        URL, ...) counts as "not valid", so the caller re-uploads the file
        instead of failing to send it.
        """
        if not url:
            return False
        try:
            async with self.session.http.head(url) as r:
                return r.status < 400
        except Exception as e:
            log.debug("Could not check whether %s is still valid", url, exc_info=e)
            return False

    async def __get_stored(self, attachment: LegacyAttachment) -> Attachment:
        """
        Return the cached :class:`Attachment` row matching the attachment's
        ``legacy_file_id``, or a new, not yet persisted row.

        When a cached row's URL no longer works, its URL and its cached SIMS
        and SFS elements (which embed that URL) are cleared, so that the file
        gets uploaded again and the stale elements are not reused.
        """
        if attachment.legacy_file_id is not None:
            with self.xmpp.store.session() as orm:
                stored = (
                    orm.query(Attachment)
                    .filter_by(legacy_file_id=str(attachment.legacy_file_id))
                    .one_or_none()
                )
                if stored is not None:
                    if not await self.__valid_url(stored.url):
                        stored.url = None  # type:ignore
                        stored.sims = None  # type:ignore
                        stored.sfs = None  # type:ignore
                    return stored
        return Attachment(
            user_account_id=None
            if self.session is NotImplemented
            else self.session.user_pk,
            legacy_file_id=None
            if attachment.legacy_file_id is None
            else str(attachment.legacy_file_id),
            url=attachment.url,
        )

    async def __write_temp_file(
        self,
        attachment: LegacyAttachment,
        file_name: Optional[str],
        content_type: Optional[str],
    ) -> Path:
        """
        Write the attachment's content into a fresh temporary directory.

        The content comes from the first of ``url`` (downloaded), ``stream``,
        ``aio_stream`` and ``data`` that is set; if none is, nothing is written
        and the returned path does not exist. Without ``file_name``, a random
        name is used, with an extension guessed from ``content_type`` when
        possible. The temporary directory is removed if writing fails.

        :raises RuntimeError: if ``attachment.stream.read()`` returns ``None``.
        :return: Path of the (temporary) file.
        """
        if file_name is None:
            file_name = str(uuid4())
            if content_type is not None:
                ext = guess_extension(content_type, strict=False)  # type:ignore
                if ext is not None:
                    file_name += ext
        temp_dir = Path(tempfile.mkdtemp())
        file_path = temp_dir / file_name
        try:
            if attachment.url:
                async with self.session.http.get(attachment.url) as r:
                    r.raise_for_status()
                    with file_path.open("wb") as f:
                        f.write(await r.read())
            elif attachment.stream is not None:
                data = attachment.stream.read()
                if data is None:
                    raise RuntimeError
                with file_path.open("wb") as f:
                    f.write(data)
            elif attachment.aio_stream is not None:
                # TODO: patch slixmpp to allow this as data source for
                #       upload_file() so we don't even have to write anything
                #       to disk.
                with file_path.open("wb") as f:
                    async for chunk in attachment.aio_stream:
                        f.write(chunk)
            elif attachment.data is not None:
                with file_path.open("wb") as f:
                    f.write(attachment.data)
        except BaseException:
            # Don't leak the temporary directory on failure or cancellation.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise
        return file_path

    async def __get_url(
        self, attachment: LegacyAttachment, stored: Attachment
    ) -> tuple[bool, Optional[Path], Optional[str]]:
        """
        Make the attachment available over HTTP and return its URL.

        :param attachment: The file to send.
        :param stored: Its database row; its ``url`` is updated after a
            successful upload when it has a ``legacy_file_id``.
        :return: ``(is_temp, local_path, url)``. ``is_temp`` means
            ``local_path`` is a temporary file the caller must delete (see
            ``__remove_temp``). ``local_path`` is ``None`` when the original
            URL is used as-is. ``url`` is ``None`` when the HTTP upload failed.
        """
        if attachment.url and config.USE_ATTACHMENT_ORIGINAL_URLS:
            return False, None, attachment.url

        file_name = self.__safe_file_name(attachment.name)
        content_type = attachment.content_type
        file_path = attachment.path

        max_len = config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH
        if file_name and len(file_name) > max_len:
            log.debug("Trimming long filename: %s", file_name)
            base, ext = os.path.splitext(file_name)
            file_name = base[: max_len - len(ext)] + ext

        if file_path is None:
            file_path = await self.__write_temp_file(
                attachment, file_name, content_type
            )
            file_name = file_path.name
            # In no-upload mode the file is published from this location
            # (e.g. symlinked), so it must not be deleted.
            is_temp = not bool(config.NO_UPLOAD_PATH)
        else:
            is_temp = False

        assert isinstance(file_path, Path)
        try:
            if config.FIX_FILENAME_SUFFIX_MIME_TYPE:
                file_name = str(fix_suffix(file_path, content_type, file_name))

            if config.NO_UPLOAD_PATH:
                local_path, new_url = await self.__no_upload(
                    file_path, file_name, stored.legacy_file_id
                )
                new_url = (
                    (config.NO_UPLOAD_URL_PREFIX or "") + "/" + urlquote(new_url)
                )
            else:
                local_path = file_path
                new_url = await self.__upload(file_path, file_name, content_type)
                if stored.legacy_file_id and new_url is not None:
                    stored.url = new_url
        except BaseException:
            if is_temp:
                self.__remove_temp(file_path)
            raise

        return is_temp, local_path, new_url

    async def __set_sims(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        attachment: LegacyAttachment,
        stored: Attachment,
    ) -> Thumbnail | None:
        """
        Append a SIMS (:xep:`0385`) reference to ``msg``.

        A reference cached in ``stored.sims`` is reused as-is. Otherwise one is
        built from the local file at ``path`` (nothing is added without it) and
        cached in ``stored.sims``; for ``image*`` content types a thumbhash
        thumbnail is computed in ``avatar_cache``'s thread pool.

        :return: The reference's thumbnail element, if any, so that it can be
            reused in the SFS element.
        """
        if stored.sims is not None:
            ref = self.xmpp["xep_0372"].stanza.Reference(xml=ET.fromstring(stored.sims))
            msg.append(ref)
            if ref["sims"]["file"].get_plugin("thumbnail", check=True):
                return ref["sims"]["file"]["thumbnail"]
            else:
                return None

        if not path:
            return None

        ref = self.xmpp["xep_0385"].get_sims(
            path, [uploaded_url], attachment.content_type, attachment.caption
        )
        if attachment.name:
            ref["sims"]["file"]["name"] = attachment.name
        thumbnail = None
        if attachment.content_type is not None and attachment.content_type.startswith(
            "image"
        ):
            try:
                h, x, y = await self.xmpp.loop.run_in_executor(
                    avatar_cache._thread_pool, get_thumbhash, path
                )
            except Exception as e:
                # A thumbnail is optional: never fail the whole send for it.
                log.debug("Could not generate a thumbhash", exc_info=e)
            else:
                thumbnail = ref["sims"]["file"]["thumbnail"]
                thumbnail["width"] = x
                thumbnail["height"] = y
                thumbnail["media-type"] = "image/thumbhash"
                thumbnail["uri"] = "data:image/thumbhash;base64," + urlquote(h)

        stored.sims = str(ref)
        msg.append(ref)

        return thumbnail

    def __set_sfs(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        attachment: LegacyAttachment,
        stored: Attachment,
        thumbnail: Optional[Thumbnail] = None,
    ) -> None:
        """
        Append a Stateless File Sharing (:xep:`0447`) element to ``msg``.

        An element cached in ``stored.sfs`` is reused as-is. Otherwise one is
        built from the local file at ``path`` (nothing is added without it),
        including ``thumbnail`` if given, and cached in ``stored.sfs``.
        """
        if stored.sfs is not None:
            msg.append(StatelessFileSharing(xml=ET.fromstring(stored.sfs)))
            return

        if not path:
            return

        sfs = self.xmpp["xep_0447"].get_sfs(
            path, [uploaded_url], attachment.content_type, attachment.caption
        )
        if attachment.name:
            sfs["file"]["name"] = attachment.name
        if thumbnail is not None:
            sfs["file"].append(thumbnail)
        stored.sfs = str(sfs)
        msg.append(sfs)

    def __send_url(
        self,
        msg: Message,
        legacy_msg_id: LegacyMessageType,
        uploaded_url: str,
        caption: Optional[str] = None,
        carbon: bool = False,
        when: Optional[datetime] = None,
        correction: bool = False,
        **kwargs,
    ) -> list[Message]:
        """
        Send ``msg`` as an attachment message whose body and OOB URL are
        ``uploaded_url``.

        With a ``caption``, the caption is sent as a separate text message,
        which is the one carrying ``legacy_msg_id``. Without one, the
        attachment message itself carries it (as a correction of it when
        ``correction`` is set).

        :return: The messages actually sent.
        """
        msg["oob"]["url"] = uploaded_url
        msg["body"] = uploaded_url
        if caption:
            m1 = self._send(msg, carbon=carbon, **kwargs)
            m2 = self.send_text(
                caption,
                legacy_msg_id=legacy_msg_id,
                when=when,
                carbon=carbon,
                correction=correction,
                **kwargs,
            )
            return [m1, m2] if m2 else [m1]
        else:
            if correction:
                msg["replace"]["id"] = self._replace_id(legacy_msg_id)
            else:
                self._set_msg_id(msg, legacy_msg_id)
            return [self._send(msg, carbon=carbon, **kwargs)]

    def __get_base_message(
        self,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        carbon: bool = False,
        correction: bool = False,
        mto: Optional[JID] = None,
    ) -> Message:
        """
        Build the message stanza that will carry an attachment.

        For a correction of a known message, every other XMPP message mapped
        to the same legacy ID is retracted first. A reply keeps only the
        referenced message's ID and author, not its body.
        """
        if correction and (original_xmpp_id := self._legacy_to_xmpp(legacy_msg_id)):
            with self.xmpp.store.session() as orm:
                xmpp_ids = self.xmpp.store.id_map.get_xmpp(
                    orm, self._recipient_pk(), str(legacy_msg_id), self.is_group
                )

            for xmpp_id in xmpp_ids:
                if xmpp_id == original_xmpp_id:
                    continue
                self.retract(xmpp_id, thread)

        if reply_to is not None and reply_to.body:
            # We cannot have a "quote fallback" for attachments since most (all?)
            # XMPP clients will only treat a message as an attachment if the
            # body is the URL and nothing else.
            reply_to_for_attachment: MessageReference | None = MessageReference(
                reply_to.legacy_id, reply_to.author
            )
        else:
            reply_to_for_attachment = reply_to

        return self._make_message(
            when=when,
            reply_to=reply_to_for_attachment,
            carbon=carbon,
            mto=mto,
            thread=thread,
        )

    async def send_file(
        self,
        attachment: LegacyAttachment | Path | str,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> tuple[Optional[str], list[Message]]:
        """
        Send a single file from this :term:`XMPP Entity`.

        :param attachment: The file to send.
            Ideally, a :class:`.LegacyAttachment` with a unique ``legacy_file_id``
            attribute set, to optimise potential future reuses.
            It can also be:
            - a :class:`pathlib.Path` instance to point to a local file, or
            - a ``str``, representing a fetchable HTTP URL.
        :param legacy_msg_id: If you want to be able to transport read markers from the gateway
            user to the legacy network, specify this
        :param reply_to: Quote another message (:xep:`0461`)
        :param when: when the file was sent, for a "delay" tag (:xep:`0203`)
        :param thread: legacy thread the message belongs to
        :param kwargs: ``store_multi`` (default ``True``: record the XMPP IDs
            of the sent messages for ``legacy_msg_id``), ``carbon``, ``mto``
            and ``correction`` are interpreted here; everything is forwarded
            when sending the messages.
        :return: The file's URL (``None`` if the upload failed, in which case
            an error message is sent instead) and the messages sent.
        """
        store_multi = kwargs.pop("store_multi", True)
        carbon = kwargs.pop("carbon", False)
        mto = kwargs.pop("mto", None)
        # Not popped: __send_url needs it too.
        correction = kwargs.get("correction", False)

        msg = self.__get_base_message(
            legacy_msg_id, reply_to, when, thread, carbon, correction, mto
        )

        if isinstance(attachment, str):
            attachment = LegacyAttachment(url=attachment)
        elif isinstance(attachment, Path):
            attachment = LegacyAttachment(path=attachment)

        stored = await self.__get_stored(attachment)

        if attachment.content_type is None and (
            name := (attachment.name or attachment.url or attachment.path)
        ):
            attachment.content_type, _ = guess_type(name)

        if stored.url:
            is_temp = False
            local_path = None
            new_url = stored.url
        else:
            is_temp, local_path, new_url = await self.__get_url(attachment, stored)
            if new_url is None:
                if is_temp and isinstance(local_path, Path):
                    self.__remove_temp(local_path)
                msg["body"] = (
                    "I tried to send a file, but something went wrong. "
                    "Tell your slidge admin to check the logs."
                )
                self._set_msg_id(msg, legacy_msg_id)
                return None, [self._send(msg, carbon=carbon, **kwargs)]

        try:
            stored.url = new_url
            # Both need the local file (when not cached), hence the cleanup
            # only afterwards.
            thumbnail = await self.__set_sims(
                msg, new_url, local_path, attachment, stored
            )
            self.__set_sfs(msg, new_url, local_path, attachment, stored, thumbnail)
        finally:
            if is_temp and isinstance(local_path, Path):
                self.__remove_temp(local_path)

        if self.session is not NotImplemented:
            with self.xmpp.store.session(expire_on_commit=False) as orm:
                orm.add(stored)
                orm.commit()

        msgs = self.__send_url(
            msg, legacy_msg_id, new_url, attachment.caption, carbon, when, **kwargs
        )
        if self.session is not NotImplemented and store_multi:
            self.__store_multi(legacy_msg_id, msgs)
        return new_url, msgs

    def __send_body(
        self,
        body: Optional[str] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> Optional[Message]:
        """Send ``body`` as a text message, or do nothing if it is empty."""
        if body:
            return self.send_text(
                body,
                legacy_msg_id,
                reply_to=reply_to,
                when=when,
                thread=thread,
                **kwargs,
            )
        else:
            return None

    async def send_files(
        self,
        attachments: Collection[LegacyAttachment],
        legacy_msg_id: Optional[LegacyMessageType] = None,
        body: Optional[str] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        body_first: bool = False,
        correction: bool = False,
        correction_event_id: Optional[LegacyMessageType] = None,
        **kwargs,
    ) -> None:
        """
        Send several files and an optional text body as separate messages.

        ``legacy_msg_id`` is attached to a single message: the text body if
        there is one, otherwise the last attachment. ``correction`` and
        ``correction_event_id`` only apply to the text body. The XMPP IDs of
        all the messages are then recorded for ``legacy_msg_id``.

        :param body_first: Send the text body before the attachments instead
            of after them.
        """
        # TODO: once the epic XEP-0385 vs XEP-0447 battle is over, pick
        #       one and stop sending several attachments this way
        # we attach the legacy_message ID to the last message we send, because
        # we don't want several messages with the same ID (especially for MUC MAM)
        if not attachments and not body:
            # ignoring empty message
            return
        send_body = functools.partial(
            self.__send_body,
            body=body,
            reply_to=reply_to,
            when=when,
            thread=thread,
            correction=correction,
            legacy_msg_id=legacy_msg_id,
            correction_event_id=correction_event_id,
            **kwargs,
        )
        all_msgs = []
        if body_first:
            all_msgs.append(send_body())
        last_attachment_i = len(attachments) - 1
        for i, attachment in enumerate(attachments):
            last = i == last_attachment_i
            if last and not body:
                legacy = legacy_msg_id
            else:
                legacy = None
            _url, msgs = await self.send_file(
                attachment,
                legacy,
                reply_to=reply_to,
                when=when,
                thread=thread,
                store_multi=False,
                **kwargs,
            )
            all_msgs.extend(msgs)
        if not body_first:
            all_msgs.append(send_body())
        # Same guard as in send_file: the ID store needs a user session.
        if self.session is not NotImplemented:
            self.__store_multi(legacy_msg_id, all_msgs)

    def __store_multi(
        self,
        legacy_msg_id: Optional[LegacyMessageType],
        all_msgs: Sequence[Optional[Message]],
    ) -> None:
        """
        Map ``legacy_msg_id`` to the XMPP IDs of all the given messages.

        The stanza-id is used when a message has one, its ``id`` otherwise;
        ``None`` entries are skipped. Nothing is stored without a legacy ID.
        """
        if legacy_msg_id is None:
            return
        ids = []
        for msg in all_msgs:
            if not msg:
                continue
            if stanza_id := msg.get_plugin("stanza_id", check=True):
                ids.append(stanza_id["id"])
            else:
                ids.append(msg.get_id())
        with self.xmpp.store.session() as orm:
            self.xmpp.store.id_map.set_msg(
                orm, self._recipient_pk(), str(legacy_msg_id), ids, self.is_group
            )
            orm.commit()
def get_thumbhash(path: Path) -> tuple[str, int, int]:
    """
    Compute the thumbhash of an image file.

    :param path: Image file to read.
    :return: ``(hash, width, height)``. ``hash`` is the base64-encoded
        thumbhash of the image, rotated according to its EXIF orientation.
        ``width`` and ``height`` are the full image's dimensions in that same
        (displayed) orientation.
    """
    with path.open("rb") as fp:
        img = Image.open(fp)
        width, height = img.size
        # EXIF orientations 5-8 include a quarter turn: the displayed image
        # (and the hashed thumbnail, after exif_transpose below) has its
        # sides swapped relative to the raw pixel data.
        if img.getexif().get(0x0112) in (5, 6, 7, 8):  # 0x0112: "Orientation"
            width, height = height, width
        img = img.convert("RGBA")
        # thumbhash expects images of at most 100x100 pixels.
        if width > 100 or height > 100:
            img.thumbnail((100, 100))
        img = ImageOps.exif_transpose(img)
        rgba = list(chain.from_iterable(img.getdata()))
    ints = thumbhash.rgba_to_thumb_hash(img.width, img.height, rgba)
    return base64.b64encode(bytes(ints)).decode(), width, height
# Logger shared by this module; it is looked up by the methods above only
# when they run, so defining it after them is fine.
log: logging.Logger = logging.getLogger(name=__name__)