Coverage for slidge / db / avatar.py: 91%

169 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import asyncio 

2import hashlib 

3import io 

4import logging 

5from concurrent.futures import ThreadPoolExecutor 

6from http import HTTPStatus 

7from pathlib import Path 

8 

9import aiohttp 

10from multidict import CIMultiDictProxy 

11from PIL.Image import Image 

12from PIL.Image import open as open_image 

13from sqlalchemy import select 

14 

15from ..core import config 

16from ..util.lock import NamedLockMixin 

17from ..util.types import Avatar as AvatarType 

18from .models import Avatar 

19from .store import AvatarStore 

20 

21 

22class CachedAvatar: 

23 def __init__(self, stored: Avatar, root_dir: Path) -> None: 

24 self.stored = stored 

25 self._root = root_dir 

26 

27 @property 

28 def pk(self) -> int | None: 

29 return self.stored.id 

30 

31 @property 

32 def hash(self) -> str: 

33 return self.stored.hash 

34 

35 @property 

36 def height(self) -> int: 

37 return self.stored.height 

38 

39 @property 

40 def width(self) -> int: 

41 return self.stored.width 

42 

43 @property 

44 def etag(self) -> str | None: 

45 return self.stored.etag 

46 

47 @property 

48 def last_modified(self) -> str | None: 

49 return self.stored.last_modified 

50 

51 @property 

52 def data(self) -> bytes: 

53 return self.path.read_bytes() 

54 

55 @property 

56 def path(self) -> Path: 

57 return (self._root / self.hash).with_suffix(".png") 

58 

59 

class NotModified(Exception):
    """Signal that an avatar download was answered with HTTP 304 Not Modified."""

62 

63 

class AvatarCache(NamedLockMixin):
    """Convert avatars to PNG files on disk and record them in the avatar store.

    ``dir``, ``http`` and ``store`` are only declared here; they must be
    assigned (``dir`` through :meth:`set_dir`) before the methods that use
    them are called.
    """

    # Directory holding the ``<hash>.png`` files.
    dir: Path
    # HTTP session used to fetch avatars that are given by URL.
    http: aiohttp.ClientSession
    # Provides the ORM sessions used to read and write ``Avatar`` rows.
    store: AvatarStore

    def __init__(self) -> None:
        # Pool in which images are resized (see ``convert``).
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored row in a :class:`CachedAvatar` rooted at ``self.dir``."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """Set the avatar directory and drop rows whose file is missing.

        The directory is created if needed (its parents are not).
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling thread pool, cancelling pending work."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: CachedAvatar | Avatar | None = None
    ) -> dict[str, str]:
        """Build conditional-request headers for a previously stored avatar.

        Returns an empty dict when *cached* is missing or its PNG file is no
        longer on disk, so that the download is then unconditional.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET *url* and return the decoded image with the response headers.

        Raises:
            NotModified: the server answered 304.
            aiohttp.ClientResponseError: any other error status
                (from ``raise_for_status``).
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url: str, headers: dict[str, str]) -> bool:
        """HEAD *url*; True unless the server answers 304 Not Modified."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """Tell whether the avatar at *url* differs from the stored copy.

        Returns True when no row exists for *url*; otherwise issues a
        conditional HEAD request using the stored validators.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
            if cached is None:
                return True
            # Read the row's attributes while the session is still open.
            headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open the image from ``avatar.data`` or, failing that, ``avatar.path``.

        Raises:
            TypeError: both ``data`` and ``path`` are None.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        if avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cached avatar for *avatar*, converting it if necessary.

        A row matching ``avatar.unique_id`` is returned as-is. Otherwise the
        avatar is fetched from its URL when it has one, or opened from its
        data/path, and then converted.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
                if stored is not None:
                    return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        return await self.convert(avatar, await self._get_image(avatar))

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """Download and convert an avatar given by URL.

        Runs under ``self.lock`` keyed on the unique ID, or on the URL when
        there is none. A row found by unique ID is returned without any HTTP
        request. A row found by URL only is revalidated with a conditional GET.
        """
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                    if stored is not None:
                        return self.get(stored)
                # Read the row's attributes while the session is still open.
                request_headers = self.__get_http_headers(stored)

            try:
                img, response_headers = await self.__download(
                    avatar.url, request_headers
                )
            except NotModified:
                # A 304 is only expected after sending validators, which
                # requires a stored row.
                assert stored is not None
                return self.get(stored)

            return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """Write *img* as ``<sha1>.png`` in ``self.dir`` and store its row.

        Args:
            avatar: source description; its ``path``, ``url`` and
                ``unique_id`` are read.
            img: the decoded image; resized in place when any side exceeds
                ``config.AVATAR_SIZE`` (no resizing if that setting is falsy).
            response_headers: HTTP response headers, if the image was
                downloaded; ``etag`` and ``last-modified`` are stored.

        Returns:
            The cached avatar: an existing row with the same hash, or a
            newly inserted one.
        """
        size = config.AVATAR_SIZE
        resize = bool(size) and any(x > size for x in img.size)
        if resize:
            # get_running_loop() is the supported call inside a coroutine.
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Already a PNG file of acceptable size: keep its exact bytes.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        file_path.write_bytes(img_bytes)

        # Lookups compare against str(unique_id), so store the same form.
        legacy_id = None if avatar.unique_id is None else str(avatar.unique_id)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                if legacy_id is not None and legacy_id != stored.legacy_id:
                    log.warning(
                        "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                        stored.legacy_id,
                        avatar.unique_id,
                    )
                    stored.legacy_id = legacy_id
                    orm.add(stored)
                    orm.commit()

                return self.get(stored)

        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            legacy_id=legacy_id,
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            if avatar.url is not None:
                # Replace any row previously stored for this URL.
                existing = orm.execute(
                    select(Avatar).filter_by(url=avatar.url)
                ).scalar_one_or_none()
                if existing is not None:
                    orm.delete(existing)
                    orm.commit()
            orm.add(stored)
            orm.commit()
            return self.get(stored)

254 

255 

# Module-level cache instance; its ``dir``, ``http`` and ``store`` attributes
# are not assigned here.
avatar_cache = AvatarCache()
log = logging.getLogger(__name__)
# NOTE(review): not referenced by the code visible in this module - confirm
# whether anything still uses it.
_download_lock = asyncio.Lock()

__all__ = (
    "CachedAvatar",
    "avatar_cache",
)

263)