Coverage for slidge / db / avatar.py: 91%
169 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
1import asyncio
2import hashlib
3import io
4import logging
5from concurrent.futures import ThreadPoolExecutor
6from http import HTTPStatus
7from pathlib import Path
9import aiohttp
10from multidict import CIMultiDictProxy
11from PIL.Image import Image
12from PIL.Image import open as open_image
13from sqlalchemy import select
15from ..core import config
16from ..util.lock import NamedLockMixin
17from ..util.types import Avatar as AvatarType
18from .models import Avatar
19from .store import AvatarStore
class CachedAvatar:
    """An ``Avatar`` database row paired with the directory holding its PNG.

    Metadata (id, hash, dimensions, HTTP validators) is read directly from the
    wrapped row. ``path`` and ``data`` locate and read the ``<hash>.png`` file
    inside ``root_dir``.
    """

    def __init__(self, stored: Avatar, root_dir: Path) -> None:
        self.stored = stored
        self._root = root_dir

    @property
    def path(self):
        """Location of the cached file: ``<root_dir>/<hash>.png``."""
        return self._root.joinpath(self.hash).with_suffix(".png")

    @property
    def data(self):
        """Raw bytes of the cached PNG file, read from disk on every access."""
        return self.path.read_bytes()

    @property
    def pk(self) -> int | None:
        """``id`` column of the wrapped row."""
        return self.stored.id

    @property
    def hash(self) -> str:
        """Hash of the PNG bytes; also used as the file stem."""
        return self.stored.hash

    @property
    def width(self) -> int:
        """Image width as recorded in the row."""
        return self.stored.width

    @property
    def height(self) -> int:
        """Image height as recorded in the row."""
        return self.stored.height

    @property
    def etag(self) -> str | None:
        """ETag response header recorded at download time, if any."""
        return self.stored.etag

    @property
    def last_modified(self) -> str | None:
        """Last-Modified response header recorded at download time, if any."""
        return self.stored.last_modified
class NotModified(Exception):
    """Signals that an HTTP server answered ``304 Not Modified``."""
class AvatarCache(NamedLockMixin):
    """Convert legacy avatars to PNG files on disk and track them in the store.

    Each avatar is written to ``<dir>/<sha1 of PNG bytes>.png``. A matching
    ``Avatar`` row records its hash, dimensions, source URL, legacy unique ID
    and the HTTP validators (ETag / Last-Modified) of the download.

    ``http`` and ``store`` are never assigned in this class, and ``dir`` is
    only assigned by :meth:`set_dir`. NOTE(review): presumably the owner
    assigns all three before first use.
    """

    dir: Path
    http: aiohttp.ClientSession
    store: AvatarStore

    def __init__(self) -> None:
        # Dedicated pool so Pillow resampling does not run on the event loop.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored row as a :class:`CachedAvatar` rooted at ``self.dir``."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """Use *path* as the avatar directory and prune stale rows.

        The directory is created if missing (non-recursively). Every stored
        avatar whose PNG file is absent from *path* is deleted from the store.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling pool, cancelling jobs not yet started."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: CachedAvatar | Avatar | None = None
    ) -> dict[str, str]:
        """Build conditional-request headers for re-fetching *cached*.

        Validators are only sent when the cached PNG still exists on disk.
        A 304 answer can therefore always be served from the local file.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET *url* and decode the response body as an image.

        :return: the opened image and the response headers.
        :raises NotModified: if the server answers 304.
        :raises aiohttp.ClientResponseError: on any other error status.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url, headers) -> bool:
        """Send a HEAD request; anything other than 304 counts as modified."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """Tell whether the avatar at *url* differs from the cached copy.

        Returns True when no row is stored for *url*. Otherwise a conditional
        HEAD request is sent using the stored validators.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
            if cached is None:
                return True
            # Read the row's attributes while the session is still open, and
            # do not keep the session open across the network round-trip.
            headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open *avatar* from its in-memory bytes, or else from its file path.

        :raises TypeError: if the avatar has neither ``data`` nor ``path``.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        elif avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cached form of *avatar*, converting it if needed.

        A row matching the avatar's legacy ``unique_id`` short-circuits
        everything. URL avatars then go through a (conditional) download.
        Avatars given as data or a path are always decoded and converted;
        deduplication happens on the PNG hash inside :meth:`convert`.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
                if stored is not None:
                    return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        return await self.convert(avatar, await self._get_image(avatar))

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """Download and convert a URL avatar, reusing the cache when possible.

        Runs under a per-avatar named lock so concurrent requests for the
        same avatar do not download it twice.
        """
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    # A row for this URL only supplies validators for a
                    # conditional GET; the remote image may have changed.
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    # Re-check under the lock: the legacy ID is authoritative.
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                    if stored is not None:
                        return self.get(stored)

            try:
                img, response_headers = await self.__download(
                    avatar.url, self.__get_http_headers(stored)
                )
            except NotModified:
                # Validators are only sent for an existing row with its file
                # present, so a 304 implies ``stored`` is usable.
                assert stored is not None
                return self.get(stored)

            return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """Store *img* as a PNG file and record it in the database.

        If ``config.AVATAR_SIZE`` is set and either dimension exceeds it, the
        image is downscaled in place with ``Image.thumbnail``. The PNG bytes
        are written to ``<dir>/<sha1>.png``. A row with the same hash is
        reused, and its legacy ID is updated if needed. Otherwise a new row
        replaces any previous row for the same URL.

        :param avatar: the source avatar (url / path / unique_id are read).
        :param img: the decoded image.
        :param response_headers: HTTP headers of the download, if any; their
            ETag / Last-Modified values are stored for conditional requests.
        """
        resize = (size := config.AVATAR_SIZE) and any(x > size for x in img.size)
        if resize:
            # Resampling is CPU-bound: run it on the dedicated pool.
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Already an untouched PNG on disk: reuse its bytes verbatim.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        with file_path.open("wb") as file:
            file.write(img_bytes)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                if avatar.unique_id is not None:
                    if str(avatar.unique_id) != stored.legacy_id:
                        log.warning(
                            "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                            stored.legacy_id,
                            avatar.unique_id,
                        )
                        stored.legacy_id = str(avatar.unique_id)
                        orm.add(stored)
                        orm.commit()

                return self.get(stored)

        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            # Store the ID as text, the same form every legacy_id lookup and
            # update in this class uses (str(avatar.unique_id)).
            legacy_id=None if avatar.unique_id is None else str(avatar.unique_id),
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            if avatar.url is not None:
                # The image behind this URL changed: drop the outdated row.
                existing = orm.execute(
                    select(Avatar).filter_by(url=avatar.url)
                ).scalar_one_or_none()
                if existing is not None:
                    orm.delete(existing)
                    orm.commit()
            orm.add(stored)
            orm.commit()
        return self.get(stored)
log = logging.getLogger(__name__)

# Module-level cache instance, exported through __all__.
avatar_cache = AvatarCache()

# NOTE(review): not referenced in this module; presumably imported elsewhere.
_download_lock = asyncio.Lock()

__all__ = ("CachedAvatar", "avatar_cache")