Coverage for slidge/core/mixins/attachment.py: 82%
270 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import base64
2import functools
3import logging
4import os
5import re
6import shutil
7import stat
8import tempfile
9import warnings
10from datetime import datetime
11from itertools import chain
12from mimetypes import guess_extension, guess_type
13from pathlib import Path
14from typing import IO, AsyncIterator, Collection, Optional, Sequence, Union
15from urllib.parse import quote as urlquote
16from uuid import uuid4
17from xml.etree import ElementTree as ET
19import thumbhash
20from PIL import Image, ImageOps
21from slixmpp import JID, Message
22from slixmpp.exceptions import IqError, IqTimeout
23from slixmpp.plugins.xep_0363 import FileUploadError
24from slixmpp.plugins.xep_0447.stanza import StatelessFileSharing
26from ...db.avatar import avatar_cache
27from ...slixfix.xep_0264.stanza import Thumbnail
28from ...util.types import (
29 LegacyAttachment,
30 LegacyMessageType,
31 LegacyThreadType,
32 MessageReference,
33)
34from ...util.util import fix_suffix
35from .. import config
36from .message_text import TextMessageMixin
class AttachmentMixin(TextMessageMixin):
    """
    Mixin that lets an :term:`XMPP Entity` send file attachments.

    Files are either uploaded through XEP-0363 (HTTP File Upload) or, when
    ``config.NO_UPLOAD_PATH`` is set, placed in that local directory with a
    URL built from ``config.NO_UPLOAD_URL_PREFIX``. Each attachment message
    carries an OOB URL plus SIMS (XEP-0385) and SFS (XEP-0447) metadata.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # Attachment store used to cache URLs and SIMS/SFS payloads.
        self.__store = self.xmpp.store.attachments

    async def __upload(
        self,
        file_path: Path,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
    ) -> Optional[str]:
        """
        Upload a file with XEP-0363 (HTTP File Upload).

        :param file_path: Local file to upload.
        :param file_name: Name to announce to the upload service. If it differs
            from ``file_path.name``, a symlink with that name is created in a
            temporary directory and uploaded instead.
        :param content_type: MIME type given to the upload service.
        :return: The URL returned by the upload service, or ``None`` (with a
            warning) if the service returned an error or timed out.
        """
        if config.UPLOAD_SERVICE:
            domain = None
        else:
            # Drop the first label of our JID: "a.example.org" -> "example.org"
            domain = re.sub(r"^.*?\.", "", self.xmpp.boundjid.bare)

        temp_dir: Optional[Path] = None
        try:
            if file_name and file_path.name != file_name:
                # The symlink is created inside the try so that the temporary
                # directory is cleaned up even if creating the link fails.
                temp_dir = Path(tempfile.mkdtemp())
                link = temp_dir / file_name
                link.symlink_to(file_path)
                file_path = link
            new_url = await self.xmpp.plugin["xep_0363"].upload_file(
                filename=file_path,
                content_type=content_type,
                ifrom=config.UPLOAD_REQUESTER or self.xmpp.boundjid,
                domain=JID(domain),
            )
        except (FileUploadError, IqError, IqTimeout) as e:
            warnings.warn(f"Something is wrong with the upload service: {e!r}")
            return None
        finally:
            if temp_dir is not None:
                # rmtree unlinks the symlink itself, never its target.
                shutil.rmtree(temp_dir, ignore_errors=True)

        return new_url

    @staticmethod
    async def __no_upload(
        file_path: Path,
        file_name: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ) -> tuple[Path, str]:
        """
        Store a file under ``config.NO_UPLOAD_PATH`` instead of uploading it.

        The file ends up at ``NO_UPLOAD_PATH/<file_id>/<uuid>/<name>``, where
        ``file_id`` is ``legacy_file_id`` if given, otherwise a random UUID.
        If a directory for this ``file_id`` already holds exactly one file,
        that file is reused. If it holds zero or several, it is deleted and
        recreated.

        :param file_path: Source file.
        :param file_name: Name of the stored file (default: ``file_path.name``).
        :param legacy_file_id: Identifier of the file on the legacy network.
        :return: The stored file's path and its path relative to
            ``NO_UPLOAD_PATH`` (not URL-quoted).
        :raises RuntimeError: If ``config.NO_UPLOAD_METHOD`` is not one of
            ``copy``, ``hardlink``, ``symlink`` or ``move``.
        """
        file_id = str(uuid4()) if legacy_file_id is None else str(legacy_file_id)
        assert config.NO_UPLOAD_PATH is not None
        assert config.NO_UPLOAD_URL_PREFIX is not None
        destination_dir = Path(config.NO_UPLOAD_PATH) / file_id

        if destination_dir.exists():
            log.debug("Dest dir exists: %s", destination_dir)
            files = [f for f in destination_dir.glob("**/*") if f.is_file()]
            if len(files) == 1:
                log.debug(
                    "Found the legacy attachment '%s' at '%s'",
                    legacy_file_id,
                    files[0],
                )
                name = files[0].name
                uu = files[0].parent.name  # anti-obvious url trick, see below
                return files[0], "/".join([file_id, uu, name])
            else:
                log.warning(
                    (
                        "There are several or zero files in %s, "
                        "slidge doesn't know which one to pick among %s. "
                        "Removing the dir."
                    ),
                    destination_dir,
                    files,
                )
                shutil.rmtree(destination_dir)

        log.debug("Did not find a file in: %s", destination_dir)
        # let's use a UUID to avoid URLs being too obvious
        uu = str(uuid4())
        destination_dir = destination_dir / uu
        destination_dir.mkdir(parents=True)

        name = file_name or file_path.name
        destination = destination_dir / name
        method = config.NO_UPLOAD_METHOD
        if method == "copy":
            shutil.copy2(file_path, destination)
        elif method == "hardlink":
            os.link(file_path, destination)
        elif method == "symlink":
            # The target is a regular file: target_is_directory must stay
            # False (it matters on Windows, and is ignored on POSIX).
            os.symlink(file_path, destination)
        elif method == "move":
            shutil.move(file_path, destination)
        else:
            raise RuntimeError("No upload method not recognized", method)

        if config.NO_UPLOAD_FILE_READ_OTHERS:
            log.debug("Changing perms of %s", destination)
            destination.chmod(destination.stat().st_mode | stat.S_IROTH)
        uploaded_url = "/".join([file_id, uu, name])

        return destination, uploaded_url

    async def __write_temp_file(
        self,
        path: Path,
        file_url: Optional[str],
        data_stream: Optional[IO[bytes]],
        async_data_stream: Optional[AsyncIterator[bytes]],
        data: Optional[bytes],
    ) -> None:
        """
        Write attachment content to ``path`` from the first available source.

        Sources are tried in this order: ``file_url`` (downloaded with the
        session's HTTP client), ``data_stream``, ``async_data_stream``,
        ``data``. If none is given, nothing is written.

        :raises RuntimeError: If ``data_stream.read()`` returns ``None``.
        """
        if file_url:
            async with self.session.http.get(file_url) as r:
                path.write_bytes(await r.read())
        elif data_stream is not None:
            content = data_stream.read()
            if content is None:
                raise RuntimeError
            path.write_bytes(content)
        elif async_data_stream is not None:
            # TODO: patch slixmpp to allow this as data source for
            # upload_file() so we don't even have to write anything
            # to disk.
            with path.open("wb") as f:
                async for chunk in async_data_stream:
                    f.write(chunk)
        elif data is not None:
            path.write_bytes(data)

    async def __get_url(
        self,
        file_path: Optional[Path] = None,
        async_data_stream: Optional[AsyncIterator[bytes]] = None,
        data_stream: Optional[IO[bytes]] = None,
        data: Optional[bytes] = None,
        file_url: Optional[str] = None,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
    ) -> tuple[bool, Optional[Path], Optional[str]]:
        """
        Obtain a URL for an attachment, uploading or storing it as needed.

        Resolution order:

        1. A cached URL for ``legacy_file_id``, if an HTTP HEAD on it returns
           a status below 400. Otherwise the cache entry is removed.
        2. ``file_url`` itself, if ``config.USE_ATTACHMENT_ORIGINAL_URLS``.
        3. Otherwise the content is materialized as a local file (a temporary
           one if ``file_path`` is not given), then stored locally
           (``config.NO_UPLOAD_PATH``) or uploaded (XEP-0363).

        :return: ``(is_temp, local_path, url)``.

            - ``is_temp`` is True when ``local_path`` is a temporary file in its
              own temporary directory, which the caller must delete.
            - ``local_path`` is None when no local file is involved.
            - ``url`` is None if the upload failed.
        """
        if legacy_file_id:
            cache = self.__store.get_url(str(legacy_file_id))
            if cache is not None:
                async with self.session.http.head(cache) as r:
                    if r.status < 400:
                        return False, None, cache
                    else:
                        self.__store.remove(str(legacy_file_id))

        if file_url and config.USE_ATTACHMENT_ORIGINAL_URLS:
            return False, None, file_url

        if file_name and len(file_name) > config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH:
            log.debug("Trimming long filename: %s", file_name)
            base, ext = os.path.splitext(file_name)
            file_name = (
                base[: config.ATTACHMENT_MAXIMUM_FILE_NAME_LENGTH - len(ext)] + ext
            )

        temp_dir: Optional[Path] = None
        if file_path is None:
            if file_name is None:
                file_name = str(uuid4())
                if content_type is not None:
                    ext = guess_extension(content_type, strict=False)  # type:ignore
                    if ext is not None:
                        file_name += ext
            temp_dir = Path(tempfile.mkdtemp())
            file_path = temp_dir / file_name
            try:
                await self.__write_temp_file(
                    file_path, file_url, data_stream, async_data_stream, data
                )
            except BaseException:
                shutil.rmtree(temp_dir, ignore_errors=True)
                raise

        # In upload mode the caller deletes the temporary file once it has
        # computed SIMS/SFS metadata from it. In no-upload mode, cleanup
        # happens right after the file has been stored (see below).
        is_temp = temp_dir is not None and not config.NO_UPLOAD_PATH

        if config.FIX_FILENAME_SUFFIX_MIME_TYPE:
            file_name = str(fix_suffix(file_path, content_type, file_name))

        if config.NO_UPLOAD_PATH:
            try:
                local_path, new_url = await self.__no_upload(
                    file_path, file_name, legacy_file_id
                )
            finally:
                # After copy/hardlink/move nothing in the temporary directory
                # is needed any more. A symlink still points into it, so it
                # must be kept.
                if temp_dir is not None and config.NO_UPLOAD_METHOD != "symlink":
                    shutil.rmtree(temp_dir, ignore_errors=True)
            new_url = (config.NO_UPLOAD_URL_PREFIX or "") + "/" + urlquote(new_url)
        else:
            local_path = file_path
            new_url = await self.__upload(file_path, file_name, content_type)

        # Never cache a failed upload.
        if legacy_file_id and new_url is not None:
            self.__store.set_url(self.session.user_pk, str(legacy_file_id), new_url)

        return is_temp, local_path, new_url

    async def __set_sims(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        content_type: Optional[str] = None,
        caption: Optional[str] = None,
        file_name: Optional[str] = None,
    ) -> Thumbnail | None:
        """
        Append a SIMS (XEP-0385) reference for ``uploaded_url`` to ``msg``.

        A reference cached for this URL is reused as-is. Otherwise one is
        built from ``path`` (nothing is appended if ``path`` is None) and
        cached. For ``image*`` content types, a thumbhash thumbnail is
        generated in the avatar cache's thread pool.

        :return: The thumbnail element, if there is one, so that it can be
            reused in the SFS element.
        """
        cache = self.__store.get_sims(uploaded_url)
        if cache:
            ref = self.xmpp["xep_0372"].stanza.Reference(xml=ET.fromstring(cache))
            msg.append(ref)
            if ref["sims"]["file"].get_plugin("thumbnail", check=True):
                return ref["sims"]["file"]["thumbnail"]
            else:
                return None

        if not path:
            return None

        ref = self.xmpp["xep_0385"].get_sims(
            path, [uploaded_url], content_type, caption
        )
        if file_name:
            ref["sims"]["file"]["name"] = file_name
        thumbnail = None
        if content_type is not None and content_type.startswith("image"):
            try:
                h, x, y = await self.xmpp.loop.run_in_executor(
                    avatar_cache._thread_pool, get_thumbhash, path
                )
            except Exception as e:
                # Best effort: an attachment without a thumbnail is fine.
                log.debug("Could not generate a thumbhash", exc_info=e)
            else:
                thumbnail = ref["sims"]["file"]["thumbnail"]
                thumbnail["width"] = x
                thumbnail["height"] = y
                thumbnail["media-type"] = "image/thumbhash"
                thumbnail["uri"] = "data:image/thumbhash;base64," + urlquote(h)

        self.__store.set_sims(uploaded_url, str(ref))

        msg.append(ref)

        return thumbnail

    def __set_sfs(
        self,
        msg: Message,
        uploaded_url: str,
        path: Optional[Path],
        content_type: Optional[str] = None,
        caption: Optional[str] = None,
        file_name: Optional[str] = None,
        thumbnail: Optional[Thumbnail] = None,
    ):
        """
        Append a Stateless File Sharing (XEP-0447) element to ``msg``.

        An element cached for ``uploaded_url`` is reused. Otherwise one is
        built from ``path`` (nothing is appended if ``path`` is None),
        optionally including ``thumbnail``, and cached.
        """
        cache = self.__store.get_sfs(uploaded_url)
        if cache:
            msg.append(StatelessFileSharing(xml=ET.fromstring(cache)))
            return

        if not path:
            return

        sfs = self.xmpp["xep_0447"].get_sfs(path, [uploaded_url], content_type, caption)
        if file_name:
            sfs["file"]["name"] = file_name
        if thumbnail is not None:
            sfs["file"].append(thumbnail)
        self.__store.set_sfs(uploaded_url, str(sfs))

        msg.append(sfs)

    def __send_url(
        self,
        msg: Message,
        legacy_msg_id: LegacyMessageType,
        uploaded_url: str,
        caption: Optional[str] = None,
        carbon=False,
        when: Optional[datetime] = None,
        correction=False,
        **kwargs,
    ) -> list[Message]:
        """
        Send ``msg`` as an attachment whose body is exactly ``uploaded_url``.

        If there is a caption, it is sent as a separate text message that
        carries ``legacy_msg_id``. Otherwise the attachment message carries
        it, either as a correction target or as its own ID.

        :return: The message(s) that were sent.
        """
        # Clients usually only render an attachment if the body is the URL.
        msg["oob"]["url"] = uploaded_url
        msg["body"] = uploaded_url
        if caption:
            m1 = self._send(msg, carbon=carbon, **kwargs)
            m2 = self.send_text(
                caption,
                legacy_msg_id=legacy_msg_id,
                when=when,
                carbon=carbon,
                correction=correction,
                **kwargs,
            )
            return [m1, m2] if m2 else [m1]
        else:
            if correction:
                msg["replace"]["id"] = self._replace_id(legacy_msg_id)
            else:
                self._set_msg_id(msg, legacy_msg_id)
            return [self._send(msg, carbon=carbon, **kwargs)]

    async def send_file(
        self,
        file_path: Optional[Union[Path, str]] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        *,
        async_data_stream: Optional[AsyncIterator[bytes]] = None,
        data_stream: Optional[IO[bytes]] = None,
        data: Optional[bytes] = None,
        file_url: Optional[str] = None,
        file_name: Optional[str] = None,
        content_type: Optional[str] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        caption: Optional[str] = None,
        legacy_file_id: Optional[Union[str, int]] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> tuple[Optional[str], list[Message]]:
        """
        Send a single file from this :term:`XMPP Entity`.

        :param file_path: Path to the attachment
        :param async_data_stream: Alternatively (and ideally) an AsyncIterator yielding bytes
        :param data_stream: Alternatively, a stream of bytes (such as a File object)
        :param data: Alternatively, a bytes object
        :param file_url: Alternatively, a URL
        :param file_name: How the file should be named.
        :param content_type: MIME type, inferred from filename if not given
        :param legacy_msg_id: If you want to be able to transport read markers from the gateway
            user to the legacy network, specify this
        :param reply_to: Quote another message (:xep:`0461`)
        :param when: when the file was sent, for a "delay" tag (:xep:`0203`)
        :param caption: an optional text that is linked to the file
        :param legacy_file_id: A unique identifier for the file on the legacy network.
            Plugins should try their best to provide it, to avoid duplicates.
        :param thread: Legacy thread identifier, used for the message and for
            retractions issued by a correction.
        :return: The attachment URL (``None`` if the upload failed, in which
            case an error text is sent instead) and the messages sent.
        """
        carbon = kwargs.pop("carbon", False)
        mto = kwargs.pop("mto", None)
        store_multi = kwargs.pop("store_multi", True)
        correction = kwargs.get("correction", False)
        if correction and (original_xmpp_id := self._legacy_to_xmpp(legacy_msg_id)):
            # The original legacy message may have been split into several
            # XMPP messages: retract all but the one being corrected.
            xmpp_ids = self.xmpp.store.multi.get_xmpp_ids(
                self.session.user_pk, original_xmpp_id
            )

            for xmpp_id in xmpp_ids:
                if xmpp_id == original_xmpp_id:
                    continue
                self.retract(xmpp_id, thread)

        if reply_to is not None and reply_to.body:
            # We cannot have a "quote fallback" for attachments since most (all?)
            # XMPP clients will only treat a message as an attachment if the
            # body is the URL and nothing else.
            reply_to_for_attachment: MessageReference | None = MessageReference(
                reply_to.legacy_id, reply_to.author
            )
        else:
            reply_to_for_attachment = reply_to

        msg = self._make_message(
            when=when,
            reply_to=reply_to_for_attachment,
            carbon=carbon,
            mto=mto,
            thread=thread,
        )

        if content_type is None and (name := (file_name or file_path or file_url)):
            content_type, _ = guess_type(name)

        is_temp, local_path, new_url = await self.__get_url(
            Path(file_path) if file_path else None,
            async_data_stream,
            data_stream,
            data,
            file_url,
            file_name,
            content_type,
            legacy_file_id,
        )

        try:
            if new_url is not None:
                thumbnail = await self.__set_sims(
                    msg, new_url, local_path, content_type, caption, file_name
                )
                self.__set_sfs(
                    msg, new_url, local_path, content_type, caption, file_name, thumbnail
                )
        finally:
            # The temporary copy is only needed to compute SIMS/SFS metadata.
            # It is removed on every path, including a failed upload.
            if is_temp and local_path is not None:
                shutil.rmtree(local_path.parent, ignore_errors=True)

        if new_url is None:
            msg["body"] = (
                "I tried to send a file, but something went wrong. "
                "Tell your slidge admin to check the logs."
            )
            self._set_msg_id(msg, legacy_msg_id)
            return None, [self._send(msg, carbon=carbon, **kwargs)]

        msgs = self.__send_url(
            msg, legacy_msg_id, new_url, caption, carbon, when, **kwargs
        )
        if store_multi:
            self.__store_multi(legacy_msg_id, msgs)
        return new_url, msgs

    def __send_body(
        self,
        body: Optional[str] = None,
        legacy_msg_id: Optional[LegacyMessageType] = None,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        **kwargs,
    ) -> Optional[Message]:
        """Send ``body`` as a text message. Return None if ``body`` is empty."""
        if body:
            return self.send_text(
                body,
                legacy_msg_id,
                reply_to=reply_to,
                when=when,
                thread=thread,
                **kwargs,
            )
        else:
            return None

    async def send_files(
        self,
        attachments: Collection[LegacyAttachment],
        legacy_msg_id: Optional[LegacyMessageType] = None,
        body: Optional[str] = None,
        *,
        reply_to: Optional[MessageReference] = None,
        when: Optional[datetime] = None,
        thread: Optional[LegacyThreadType] = None,
        body_first=False,
        correction=False,
        correction_event_id: Optional[LegacyMessageType] = None,
        **kwargs,
    ):
        """
        Send several attachments and an optional text body.

        Each attachment is sent as its own message, and the body (if any) as
        a separate text message, before or after the attachments depending
        on ``body_first``. All resulting XMPP IDs are recorded for
        ``legacy_msg_id``. Nothing is sent if there are neither attachments
        nor a body.
        """
        # TODO: once the epic XEP-0385 vs XEP-0447 battle is over, pick
        # one and stop sending several attachments this way
        # The legacy message ID is attached to a single message only, because
        # we don't want several messages with the same ID (especially for MUC
        # MAM): the body message if there is a body, otherwise the last
        # attachment.
        # TODO: refactor this so we limit the number of SQL calls, ie, if
        # the legacy file ID is known, only fetch the row once, and if it
        # is new, write it all in a single call
        if not attachments and not body:
            # ignoring empty message
            return
        send_body = functools.partial(
            self.__send_body,
            body=body,
            reply_to=reply_to,
            when=when,
            thread=thread,
            correction=correction,
            legacy_msg_id=legacy_msg_id,
            correction_event_id=correction_event_id,
            **kwargs,
        )
        all_msgs = []
        if body_first:
            all_msgs.append(send_body())
        last_attachment_i = len(attachments) - 1
        for i, attachment in enumerate(attachments):
            last = i == last_attachment_i
            if last and not body:
                legacy = legacy_msg_id
            else:
                legacy = None
            _url, msgs = await self.send_file(
                file_path=attachment.path,
                legacy_msg_id=legacy,
                file_url=attachment.url,
                data_stream=attachment.stream,
                data=attachment.data,
                reply_to=reply_to,
                when=when,
                thread=thread,
                file_name=attachment.name,
                content_type=attachment.content_type,
                legacy_file_id=attachment.legacy_file_id,
                caption=attachment.caption,
                store_multi=False,
                **kwargs,
            )
            all_msgs.extend(msgs)
        if not body_first:
            all_msgs.append(send_body())
        self.__store_multi(legacy_msg_id, all_msgs)

    def __store_multi(
        self,
        legacy_msg_id: Optional[LegacyMessageType],
        all_msgs: Sequence[Optional[Message]],
    ):
        """
        Record the XMPP IDs of all messages sent for one legacy message.

        The stanza-id plugin value is preferred over the message ID when
        present. ``None`` entries are skipped, and nothing is recorded if
        ``legacy_msg_id`` is None.
        """
        if legacy_msg_id is None:
            return
        ids = []
        for msg in all_msgs:
            if not msg:
                continue
            if stanza_id := msg.get_plugin("stanza_id", check=True):
                ids.append(stanza_id["id"])
            else:
                ids.append(msg.get_id())
        self.xmpp.store.multi.set_xmpp_ids(
            self.session.user_pk, str(legacy_msg_id), ids
        )
def get_thumbhash(path: Path) -> tuple[str, int, int]:
    """
    Compute a ThumbHash placeholder for an image file.

    The EXIF orientation is applied before anything else, so the returned
    dimensions describe the image as it is meant to be displayed and agree
    with the aspect ratio encoded in the hash.

    :param path: Image file readable by Pillow.
    :return: ``(hash, width, height)``. ``hash`` is the base64-encoded
        ThumbHash. ``width`` and ``height`` are the dimensions of the full,
        orientation-corrected image.
    """
    with path.open("rb") as fp:
        img = Image.open(fp)
        # Transpose *before* reading the size; otherwise width and height come
        # out swapped for photos whose EXIF orientation includes a 90° turn.
        img = ImageOps.exif_transpose(img)
        width, height = img.size
        img = img.convert("RGBA")
        # ThumbHash encoding only accepts images of at most 100x100 pixels.
        if width > 100 or height > 100:
            img.thumbnail((100, 100))
        rgba = list(chain.from_iterable(img.getdata()))
        ints = thumbhash.rgba_to_thumb_hash(img.width, img.height, rgba)
    return base64.b64encode(bytes(ints)).decode(), width, height
# Logger named after this module.
log = logging.getLogger(name=__name__)