Coverage for slidge/db/avatar.py: 92%
132 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import asyncio
2import hashlib
3import io
4import logging
5import uuid
6from concurrent.futures import ThreadPoolExecutor
7from dataclasses import dataclass
8from http import HTTPStatus
9from pathlib import Path
10from typing import Optional
12import aiohttp
13from multidict import CIMultiDictProxy
14from PIL.Image import Image
15from PIL.Image import open as open_image
16from sqlalchemy import select
18from slidge.core import config
19from slidge.db.models import Avatar
20from slidge.db.store import AvatarStore
21from slidge.util.types import URL, AvatarType
@dataclass
class CachedAvatar:
    """An avatar image file in the cache directory together with its metadata.

    ``root`` is the cache directory and ``filename`` the file inside it.
    ``etag`` and ``last_modified`` hold the HTTP validators recorded when the
    image was downloaded; they are ``None`` when nothing was recorded.
    """

    pk: int
    filename: str
    hash: str
    height: int
    width: int
    root: Path
    etag: Optional[str] = None
    last_modified: Optional[str] = None

    @property
    def data(self) -> bytes:
        """Contents of the image file, read from disk on every access."""
        with self.path.open("rb") as handle:
            return handle.read()

    @property
    def path(self) -> Path:
        """Full location of the image file: ``root`` joined with ``filename``."""
        return self.root.joinpath(self.filename)

    @staticmethod
    def from_store(stored: Avatar, root_dir: Path):
        """Build a ``CachedAvatar`` from a stored ``Avatar`` row.

        :param stored: database row describing the avatar.
        :param root_dir: directory that contains the row's file.
        """
        return CachedAvatar(
            root=root_dir,
            pk=stored.id,
            filename=stored.filename,
            hash=stored.hash,
            width=stored.width,
            height=stored.height,
            etag=stored.etag,
            last_modified=stored.last_modified,
        )
class NotModified(Exception):
    """Raised when an avatar download is answered with ``304 Not Modified``.

    It tells the caller that the copy it already has cached is still current.
    """
class AvatarCache:
    """Cache of avatar images kept as PNG files plus rows in an ``AvatarStore``.

    Image files live in :attr:`dir`. Their metadata (dimensions, SHA-1 hash of
    the PNG bytes and, for downloaded avatars, the source URL and HTTP
    validators) lives in :attr:`store`. Remote avatars are fetched with
    :attr:`http`.

    ``dir`` is assigned by :meth:`set_dir`. ``http`` and ``store`` are never
    assigned in this class; presumably the owner sets them before use
    (TODO confirm against the caller).
    """

    dir: Path
    http: aiohttp.ClientSession
    store: AvatarStore

    def __init__(self):
        # Worker threads that run CPU-bound resampling off the event loop.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)

    def set_dir(self, path: Path):
        """Use *path* as the cache directory and prune rows whose file is gone.

        The directory is created if it is missing; its parent must already
        exist because ``mkdir`` is called without ``parents=True``.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        with self.store.session():
            # Materialise the rows first: rows are deleted inside the loop, and
            # deleting while a lazy result is still being iterated is fragile.
            for stored in list(self.store.get_all()):
                avatar = CachedAvatar.from_store(stored, root_dir=path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                self.store.delete_by_pk(stored.id)

    def close(self):
        """Shut down the resampling thread pool, cancelling jobs not yet started."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(self, cached: Optional[CachedAvatar | Avatar]):
        """Build conditional-request headers from a previously cached avatar.

        Returns an empty dict when there is no cached entry or its file no
        longer exists. In that case the server must send the full image rather
        than ``304 Not Modified``.
        """
        headers = {}
        if cached and (self.dir / cached.filename).exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET *url* and decode the response body as an image.

        :returns: the decoded image and the response headers.
        :raises NotModified: if the server answered ``304 Not Modified``.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            # NOTE(review): other non-2xx statuses are not checked; an error
            # body is handed to Pillow as-is and will fail to decode there.
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url, headers) -> bool:
        """HEAD *url* with *headers*; True unless the server answers 304."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: URL) -> bool:
        """Return whether the avatar at *url* changed since it was cached.

        Returns True when nothing is stored for *url*. Otherwise a conditional
        HEAD request is made with the stored validators.
        """
        # Read the row inside a session, as convert_or_get does, and derive
        # the headers there, so the network await needs no ORM object.
        with self.store.session():
            cached = self.store.get_by_url(url)
            if cached is None:
                return True
            headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    def get_by_pk(self, pk: int) -> CachedAvatar:
        """Return the stored avatar with primary key *pk* (asserted to exist)."""
        stored = self.store.get_by_pk(pk)
        assert stored is not None
        return CachedAvatar.from_store(stored, self.dir)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open *avatar* (raw bytes or a filesystem ``Path``) with Pillow.

        :raises TypeError: for any other avatar type.
        """
        if isinstance(avatar, bytes):
            return open_image(io.BytesIO(avatar))
        elif isinstance(avatar, Path):
            return open_image(avatar)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cache entry for *avatar*, converting and storing it if new.

        - URL/str avatars are downloaded using the validators from the previous
          download of the same URL. On ``304 Not Modified`` the previously
          stored avatar is returned unchanged.
        - bytes/``Path`` avatars are opened locally.

        Images larger than ``config.AVATAR_SIZE`` in either dimension are
        downscaled with ``Image.thumbnail`` in the thread pool. The result is
        encoded as PNG; a PNG file on disk that needs no resizing is copied
        verbatim. Entries are deduplicated by the SHA-1 of the PNG bytes.

        :raises TypeError: for unsupported avatar types.
        """
        if isinstance(avatar, (URL, str)):
            with self.store.session():
                stored = self.store.get_by_url(avatar)
                # Snapshot the row while its session is open. The download
                # below can then use it without a live ORM object.
                cached = (
                    None
                    if stored is None
                    else CachedAvatar.from_store(stored, self.dir)
                )
            try:
                img, response_headers = await self.__download(
                    avatar, self.__get_http_headers(cached)
                )
            except NotModified:
                assert cached is not None
                return cached
        else:
            img = await self._get_image(avatar)
            response_headers = None

        try:
            with self.store.session() as orm:
                size = config.AVATAR_SIZE
                resize = bool(size) and any(x > size for x in img.size)
                if resize:
                    await asyncio.get_running_loop().run_in_executor(
                        self._thread_pool, img.thumbnail, (size, size)
                    )
                    log.debug("Resampled image to %s", img.size)

                if (
                    not resize
                    and img.format == "PNG"
                    and isinstance(avatar, (str, Path))
                    and (path := Path(avatar))
                    and path.exists()
                ):
                    # Already a PNG within the size limit: keep the bytes as-is.
                    img_bytes = path.read_bytes()
                else:
                    with io.BytesIO() as f:
                        img.save(f, format="PNG")
                        img_bytes = f.getvalue()

                hash_ = hashlib.sha1(img_bytes).hexdigest()

                stored = orm.execute(
                    select(Avatar).where(Avatar.hash == hash_)
                ).scalar()
                if stored is not None:
                    # Same content is already cached. Do not write a new file:
                    # no row would reference it. Restore the existing row's file
                    # if it has disappeared from disk.
                    existing_path = self.dir / stored.filename
                    if not existing_path.exists():
                        existing_path.write_bytes(img_bytes)
                    return CachedAvatar.from_store(stored, self.dir)

                filename = str(uuid.uuid1()) + ".png"
                (self.dir / filename).write_bytes(img_bytes)

                stored = Avatar(
                    filename=filename,
                    hash=hash_,
                    height=img.height,
                    width=img.width,
                    url=avatar if isinstance(avatar, (URL, str)) else None,
                )
                if response_headers:
                    stored.etag = response_headers.get("etag")
                    stored.last_modified = response_headers.get("last-modified")

                orm.add(stored)
                orm.commit()
                return CachedAvatar.from_store(stored, self.dir)
        finally:
            # Pillow opens images lazily. If neither thumbnail() nor save() ran
            # (the verbatim-PNG path), the source file would otherwise stay open.
            img.close()
log = logging.getLogger(__name__)

# Module-level AvatarCache instance exported for the rest of the package.
avatar_cache = AvatarCache()

# NOTE(review): not referenced anywhere in this module's visible code;
# presumably imported elsewhere — confirm before removing.
_download_lock = asyncio.Lock()

__all__ = (
    "CachedAvatar",
    "avatar_cache",
)