Coverage for slidge/db/avatar.py: 90%
165 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1import asyncio
2import hashlib
3import io
4import logging
5from concurrent.futures import ThreadPoolExecutor
6from http import HTTPStatus
7from pathlib import Path
8from typing import Optional
10import aiohttp
11from multidict import CIMultiDictProxy
12from PIL.Image import Image
13from PIL.Image import open as open_image
14from sqlalchemy import select
16from ..core import config
17from ..util.lock import NamedLockMixin
18from ..util.types import Avatar as AvatarType
19from .models import Avatar
20from .store import AvatarStore
class CachedAvatar:
    """Pair a stored avatar row with the directory that holds its PNG file.

    Every property except ``path`` and ``data`` simply forwards to the
    wrapped ``stored`` object.
    """

    def __init__(self, stored: Avatar, root_dir: Path) -> None:
        """Keep a reference to the row and to the directory of avatar files."""
        self._root = root_dir
        self.stored = stored

    @property
    def pk(self) -> int | None:
        """The ``id`` of the wrapped row."""
        return self.stored.id

    @property
    def hash(self) -> str:
        """The ``hash`` of the wrapped row; also the file name stem on disk."""
        return self.stored.hash

    @property
    def height(self) -> int:
        """The ``height`` of the wrapped row."""
        return self.stored.height

    @property
    def width(self) -> int:
        """The ``width`` of the wrapped row."""
        return self.stored.width

    @property
    def etag(self) -> str | None:
        """The ``etag`` of the wrapped row, if any."""
        return self.stored.etag

    @property
    def last_modified(self) -> str | None:
        """The ``last_modified`` value of the wrapped row, if any."""
        return self.stored.last_modified

    @property
    def path(self) -> Path:
        """The root directory joined with the hash, suffix set to ``.png``."""
        stem = self._root / self.hash
        return stem.with_suffix(".png")

    @property
    def data(self) -> bytes:
        """The bytes of the file at :attr:`path`, read on every access."""
        png_file = self.path
        return png_file.read_bytes()
class NotModified(Exception):
    """Raised when a download is answered with HTTP 304 (Not Modified)."""
class AvatarCache(NamedLockMixin):
    """Cache of avatars kept as PNG files named after the SHA-1 of their bytes.

    Image bytes live in :attr:`dir`; the matching metadata (hash, dimensions,
    source URL, legacy ID, HTTP caching headers) lives in ``Avatar`` rows
    reached through :attr:`store`.
    """

    # Directory holding the PNG files; assigned by set_dir().
    dir: Path
    # HTTP session used for downloads. Not assigned anywhere in this class --
    # NOTE(review): presumably injected by the application before first use.
    http: aiohttp.ClientSession
    # Database access object. Also not assigned here -- TODO confirm where.
    store: AvatarStore

    def __init__(self) -> None:
        """Create the thread pool that image resampling is offloaded to."""
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored row in a ``CachedAvatar`` rooted at :attr:`dir`."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """Set the avatar directory and drop rows whose PNG file is missing.

        The directory is created if needed (parents are not created). Every
        ``Avatar`` row is then checked; rows without a file on disk are
        deleted, and the deletions are committed together at the end.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling thread pool, cancelling pending work."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: Optional[CachedAvatar | Avatar] = None
    ) -> dict[str, str]:
        """Build conditional-request headers for a previously stored avatar.

        Headers are only produced when ``cached`` is given AND its PNG file
        still exists on disk; otherwise an empty dict is returned so that the
        server sends the full image.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET ``url`` and return the opened image plus the response headers.

        Raises:
            NotModified: the server answered 304.
            aiohttp.ClientResponseError: any other error status, via
                ``raise_for_status()``.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url, headers) -> bool:
        """Send a HEAD request and report whether the status is not 304."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """Tell whether the avatar at ``url`` differs from what is stored.

        Returns ``True`` when no row has this URL; otherwise the result of a
        conditional HEAD request built from the stored caching headers.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
        if cached is None:
            return True
        headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open the image from ``avatar.data`` if set, else from ``avatar.path``.

        Raises:
            TypeError: neither ``data`` nor ``path`` is set.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        elif avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cached avatar for ``avatar``, converting it if needed.

        A row matching ``str(avatar.unique_id)`` is returned directly. Failing
        that, avatars with a URL are downloaded, and the others are opened from
        their bytes or path, then converted.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
                if stored is not None:
                    return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        return await self.convert(avatar, await self._get_image(avatar))

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """Download (if needed) and convert an avatar identified by a URL.

        The work happens under ``self.lock(...)`` keyed on the unique ID, or
        on the URL when there is none -- presumably to serialize concurrent
        requests for the same avatar (the lock comes from the base class).
        """
        # Narrowing for the type checker; the caller already checked this.
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    # Without a unique ID, the URL row only supplies the
                    # caching headers for the conditional request below.
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                    if stored is not None:
                        return self.get(stored)

            try:
                img, response_headers = await self.__download(
                    avatar.url, self.__get_http_headers(stored)
                )
            except NotModified:
                # Conditional headers are only sent when a row was found, so
                # a 304 implies ``stored`` is set.
                assert stored is not None
                return self.get(stored)

            return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """Write ``img`` to the cache as PNG and record it in the store.

        Args:
            avatar: Source description; its ``path``, ``url`` and
                ``unique_id`` are read.
            img: The opened image. It is resized in place when any side
                exceeds ``config.AVATAR_SIZE`` (and that setting is truthy).
            response_headers: Headers of the download, if any; their
                ``etag`` and ``last-modified`` values are saved on new rows.

        Returns:
            The ``CachedAvatar`` for the existing or newly created row.
        """
        resize = (size := config.AVATAR_SIZE) and any(x > size for x in img.size)
        if resize:
            # We are inside a coroutine, so the running loop is the right one;
            # the thumbnail work runs in the pool to keep the loop responsive.
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Untouched PNG file: reuse its bytes instead of re-encoding.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        file_path.write_bytes(img_bytes)

        # Lookups filter on ``str(unique_id)``, so always store the string
        # form; ``None`` must stay ``None`` rather than become "None".
        legacy_id = None if avatar.unique_id is None else str(avatar.unique_id)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                # Same image already known: only refresh its legacy ID.
                if legacy_id is not None and legacy_id != stored.legacy_id:
                    log.warning(
                        "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                        stored.legacy_id,
                        avatar.unique_id,
                    )
                    stored.legacy_id = legacy_id
                    orm.add(stored)
                    orm.commit()

                return self.get(stored)

        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            legacy_id=legacy_id,
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            orm.add(stored)
            orm.commit()
            return self.get(stored)
# Module logger, created before the singletons below.
log = logging.getLogger(__name__)

# Shared cache instance exported to the rest of the package.
avatar_cache = AvatarCache()

# Not referenced by any code visible in this module.
_download_lock = asyncio.Lock()

__all__ = ("CachedAvatar", "avatar_cache")