Coverage for slidge / db / avatar.py: 91%
169 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1import asyncio
2import hashlib
3import io
4import logging
5from concurrent.futures import ThreadPoolExecutor
6from http import HTTPStatus
7from pathlib import Path
9import aiohttp
10from multidict import CIMultiDictProxy
11from PIL.Image import Image
12from PIL.Image import open as open_image
13from sqlalchemy import select
15from ..core import config
16from ..util.lock import NamedLockMixin
17from ..util.types import Avatar as AvatarType
18from .models import Avatar
19from .store import AvatarStore
class CachedAvatar:
    """Pair a stored avatar row with the directory that holds its PNG file."""

    def __init__(self, stored: Avatar, root_dir: Path) -> None:
        self._root = root_dir
        self.stored = stored

    @property
    def pk(self) -> int | None:
        """The ``id`` of the underlying stored row."""
        row = self.stored
        return row.id

    @property
    def hash(self) -> str:
        """The ``hash`` of the underlying stored row."""
        row = self.stored
        return row.hash

    @property
    def height(self) -> int:
        """The ``height`` of the underlying stored row."""
        row = self.stored
        return row.height

    @property
    def width(self) -> int:
        """The ``width`` of the underlying stored row."""
        row = self.stored
        return row.width

    @property
    def etag(self) -> str | None:
        """The ``etag`` of the underlying stored row."""
        row = self.stored
        return row.etag

    @property
    def last_modified(self) -> str | None:
        """The ``last_modified`` of the underlying stored row."""
        row = self.stored
        return row.last_modified

    @property
    def data(self) -> bytes:
        """The raw content of the file at :attr:`path`, read on each access."""
        with self.path.open("rb") as handle:
            return handle.read()

    @property
    def path(self) -> Path:
        """Location of the file: ``<root_dir>/<hash>`` with a ``.png`` suffix."""
        base = self._root / self.hash
        return base.with_suffix(".png")
class NotModified(Exception):
    """Raised when the server answers a download request with 304 Not Modified."""
class AvatarCache(NamedLockMixin):
    """
    Download, normalise and persist avatars as PNG files plus database rows.

    ``dir`` is assigned by :meth:`set_dir`. ``http`` and ``store`` are only
    declared here; nothing in this class assigns them, so they have to be set
    on the instance before the methods that use them are called.
    """

    dir: Path
    http: aiohttp.ClientSession
    store: AvatarStore

    def __init__(self) -> None:
        # Pool used by convert() to run image resampling off the event loop.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored row in a :class:`CachedAvatar` rooted at ``self.dir``."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """
        Set the directory holding the PNG files and prune dangling rows.

        The directory is created if missing (its parent must already exist).
        Every stored avatar whose PNG file is absent from ``path`` is deleted
        from the store, with a warning.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling thread pool, cancelling pending work."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: CachedAvatar | Avatar | None = None
    ) -> dict[str, str]:
        """
        Build conditional-request headers for an already known avatar.

        Headers are only produced when the corresponding PNG file is still on
        disk; otherwise an empty dict is returned so that a full download
        takes place.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """
        Fetch ``url`` and open the response body as an image.

        :return: The opened image and the response headers.
        :raises NotModified: If the server replies with 304 Not Modified.

        Other error statuses raise through ``response.raise_for_status()``.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url: str, headers: dict[str, str]) -> bool:
        """Issue a HEAD request; any status other than 304 counts as modified."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """
        Tell whether the avatar behind ``url`` needs to be fetched again.

        Returns ``True`` when no stored avatar has this URL; otherwise the
        answer comes from a conditional HEAD request.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
        if cached is None:
            return True
        headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """
        Open the image of an avatar given as raw bytes or as a file path.

        :raises TypeError: If the avatar has neither ``data`` nor ``path``.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        elif avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """
        Return the cached version of ``avatar``, converting it if necessary.

        A stored row matching the avatar's unique ID is returned directly.
        Otherwise URL avatars go through the download path, and bytes/path
        avatars are opened and converted.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
                if stored is not None:
                    return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        return await self.convert(avatar, await self._get_image(avatar))

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """
        Download (or revalidate) a URL avatar and convert it.

        Runs under a named lock keyed on the unique ID, or on the URL when
        there is no unique ID.
        """
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    # Matched by URL only: the row is kept so the request
                    # below can be made conditional on its etag/last-modified.
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                    # A unique-ID match is returned without contacting the
                    # server at all.
                    if stored is not None:
                        return self.get(stored)

            try:
                img, response_headers = await self.__download(
                    avatar.url, self.__get_http_headers(stored)
                )
            except NotModified:
                # Conditional headers are only sent when a stored row exists.
                assert stored is not None
                return self.get(stored)

            return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """
        Normalise ``img`` to PNG, write it to disk and record it in the store.

        :param avatar: Source description (unique ID, URL, optional path).
        :param img: The opened image; it is resized in place when any side
            exceeds ``config.AVATAR_SIZE``.
        :param response_headers: HTTP response headers of the download, if
            any; ``etag`` and ``last-modified`` are saved on new rows.
        :return: The cached avatar backed by the stored row.
        """
        resize = (size := config.AVATAR_SIZE) and any(x > size for x in img.size)
        if resize:
            # get_running_loop() is the documented call from coroutine context.
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Unresized PNG available on disk: reuse its bytes unchanged.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        # The SHA-1 hex digest of the PNG bytes names the file and keys the row.
        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        with file_path.open("wb") as file:
            file.write(img_bytes)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                # Same image already stored: at most refresh its unique ID.
                if avatar.unique_id is not None:
                    unique_id = str(avatar.unique_id)
                    if unique_id != stored.legacy_id:
                        log.warning(
                            "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                            stored.legacy_id,
                            avatar.unique_id,
                        )
                        stored.legacy_id = unique_id
                        orm.add(stored)
                        orm.commit()

                return self.get(stored)

        # Store the unique ID as a string, which is the form used by every
        # lookup on ``legacy_id`` in this class.
        legacy_id = None if avatar.unique_id is None else str(avatar.unique_id)
        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            legacy_id=legacy_id,
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            if avatar.url is not None:
                # A row with the same URL but a different hash is outdated:
                # drop it before inserting the new one.
                existing = orm.execute(
                    select(Avatar).filter_by(url=avatar.url)
                ).scalar_one_or_none()
                if existing is not None:
                    orm.delete(existing)
                    orm.commit()
            orm.add(stored)
            orm.commit()
            return self.get(stored)
# Logger used by the classes defined in this module.
log = logging.getLogger(__name__)

# Module-level lock; not referenced by the code visible in this module.
_download_lock = asyncio.Lock()

# Shared cache instance exported to the rest of the package.
avatar_cache = AvatarCache()

__all__ = (
    "CachedAvatar",
    "avatar_cache",
)