Coverage for slidge / db / avatar.py: 91%

169 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

1import asyncio 

2import hashlib 

3import io 

4import logging 

5from concurrent.futures import ThreadPoolExecutor 

6from http import HTTPStatus 

7from pathlib import Path 

8 

9import aiohttp 

10from multidict import CIMultiDictProxy 

11from PIL.Image import Image 

12from PIL.Image import open as open_image 

13from sqlalchemy import select 

14 

15from ..core import config 

16from ..util.lock import NamedLockMixin 

17from ..util.types import Avatar as AvatarType 

18from .models import Avatar 

19from .store import AvatarStore 

20 

21 

class CachedAvatar:
    """An ``Avatar`` DB row paired with the directory holding its PNG file.

    Selected row columns are exposed as read-only properties; the image
    itself is read from ``<root_dir>/<hash>.png`` on demand.
    """

    def __init__(self, stored: Avatar, root_dir: Path) -> None:
        self.stored = stored
        self._root = root_dir

    @property
    def pk(self) -> int | None:
        """Primary key of the stored row (may be ``None``)."""
        return self.stored.id

    @property
    def hash(self) -> str:
        """Hex digest that names the PNG file."""
        return self.stored.hash

    @property
    def height(self) -> int:
        """Image height as recorded on the row."""
        return self.stored.height

    @property
    def width(self) -> int:
        """Image width as recorded on the row."""
        return self.stored.width

    @property
    def etag(self) -> str | None:
        """``ETag`` of the HTTP response the image came from, if any."""
        return self.stored.etag

    @property
    def last_modified(self) -> str | None:
        """``Last-Modified`` of the HTTP response the image came from, if any."""
        return self.stored.last_modified

    @property
    def data(self) -> bytes:
        """Raw PNG bytes, read from disk on every access."""
        with self.path.open("rb") as fh:
            return fh.read()

    @property
    def path(self) -> Path:
        """Location of the cached PNG: ``<root_dir>/<hash>.png``."""
        return self._root.joinpath(self.hash).with_suffix(".png")

58 

59 

class NotModified(Exception):
    """Raised when an avatar URL answers a request with HTTP 304 Not Modified."""

62 

63 

class AvatarCache(NamedLockMixin):
    """Cache of avatar images as PNG files on disk, indexed by ``Avatar`` rows.

    Every image is normalised to PNG (downscaled first when it exceeds
    ``config.AVATAR_SIZE``) and written to ``<dir>/<sha1 of PNG bytes>.png``.
    Images fetched from a URL also remember the response's ``ETag`` and
    ``Last-Modified`` headers so later fetches can be conditional requests.
    """

    # Cache directory; assigned by set_dir().
    dir: Path
    # HTTP client used for downloads; not assigned in this class, presumably
    # set by the owner before any URL avatar is fetched.
    http: aiohttp.ClientSession
    # Database access; not assigned in this class, presumably set by the owner
    # before set_dir() or any lookup.
    store: AvatarStore

    def __init__(self) -> None:
        # Pool used to run the blocking PIL resize off the event loop.
        # NOTE(review): the pool size is read from config when this object is
        # constructed; confirm config is final by then.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored ``Avatar`` row as a :class:`CachedAvatar` in ``self.dir``."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """Use *path* as the cache directory and prune rows whose file is gone.

        The directory (and any missing parents) is created if needed. Every
        ``Avatar`` row whose PNG file does not exist in *path* is deleted.
        """
        self.dir = path
        # parents=True: a not-yet-existing parent directory must not make
        # this fail with FileNotFoundError.
        self.dir.mkdir(parents=True, exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling thread pool, cancelling queued jobs."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: CachedAvatar | Avatar | None = None
    ) -> dict[str, str]:
        """Build conditional-request headers from a previously cached avatar.

        Headers are only produced when the cached PNG still exists on disk:
        a 304 reply is only useful if there is a file to fall back on.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET *url* and open the response body as a PIL image.

        :return: the image and the response headers.
        :raises NotModified: if the server answers 304.
        :raises: whatever ``response.raise_for_status()`` raises on an error
            status.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            # The whole body is read into memory before the response closes.
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url: str, headers: dict[str, str]) -> bool:
        """HEAD *url* with *headers*; ``False`` only on a 304 reply."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """Tell whether *url* must be (re)fetched.

        ``True`` when no row is stored for *url*; otherwise the outcome of a
        HEAD request carrying the stored validators.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
            if cached is None:
                return True
            # Read the row's columns while the session is still open.
            headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open the image from ``avatar.data`` (preferred) or ``avatar.path``.

        :raises TypeError: if the avatar carries neither.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        if avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cache entry for *avatar*, converting/downloading if needed.

        Lookup order: a stored row with the same unique ID; then, for URL
        avatars, the URL download path; otherwise the image is opened from
        ``data``/``path`` and converted.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
                if stored is not None:
                    return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        return await self.convert(avatar, await self._get_image(avatar))

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """Fetch (or revalidate) a URL avatar and return its cache entry.

        Runs under a lock keyed on the avatar's unique ID (or its URL when it
        has none). An avatar whose unique ID is already stored is returned
        without network access. Otherwise the URL is downloaded, conditionally
        when a row for that URL is stored, and a 304 reply reuses that row.
        """
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                    # Re-checked under the lock: another task may have stored
                    # this avatar since convert_or_get() looked it up.
                    if stored is not None:
                        return self.get(stored)

            try:
                img, response_headers = await self.__download(
                    avatar.url, self.__get_http_headers(stored)
                )
            except NotModified:
                # A 304 is expected only when conditional headers were sent,
                # which requires a stored row.
                assert stored is not None
                return self.get(stored)

            return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """Store *img* as a PNG in the cache and return the matching entry.

        :param avatar: the source description. Its ``url`` and ``unique_id``
            are recorded on the row. Its ``path`` is reused verbatim when the
            image is already a PNG that needs no resizing.
        :param img: the decoded image.
        :param response_headers: headers of the HTTP response *img* came
            from, if any. Their ``ETag``/``Last-Modified`` are stored for later
            conditional requests.
        """
        size = config.AVATAR_SIZE
        resize = size and any(x > size for x in img.size)
        if resize:
            # Image.thumbnail is blocking and modifies img in place (img is
            # reused below), so run it in the dedicated pool.
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Already an acceptable PNG on disk: keep its exact bytes.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        # Content-addressed file name: one file per distinct PNG content.
        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        with file_path.open("wb") as file:
            file.write(img_bytes)

        # Every lookup compares legacy_id against str(unique_id), so always
        # store it as a string too.
        legacy_id = None if avatar.unique_id is None else str(avatar.unique_id)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                dirty = False
                if legacy_id is not None and legacy_id != stored.legacy_id:
                    log.warning(
                        "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                        stored.legacy_id,
                        avatar.unique_id,
                    )
                    stored.legacy_id = legacy_id
                    dirty = True
                if (
                    response_headers
                    and avatar.url is not None
                    and stored.url == avatar.url
                ):
                    # The same URL served the same bytes with a full response.
                    # Refresh the validators so the next conditional request
                    # can get a 304 instead of re-downloading forever.
                    stored.etag = response_headers.get("etag")
                    stored.last_modified = response_headers.get("last-modified")
                    dirty = True
                if dirty:
                    orm.add(stored)
                    orm.commit()
                return self.get(stored)

        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            legacy_id=legacy_id,
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            if avatar.url is not None:
                # Keep at most one row per URL: drop the outdated image's row
                # before inserting the new one.
                existing = orm.execute(
                    select(Avatar).filter_by(url=avatar.url)
                ).scalar_one_or_none()
                if existing is not None:
                    orm.delete(existing)
                    orm.commit()
            orm.add(stored)
            orm.commit()
        return self.get(stored)

252 

253 

log = logging.getLogger(__name__)

# Process-wide avatar cache. `dir`, `store` and `http` are not set by
# AvatarCache.__init__ and must be assigned before use.
avatar_cache = AvatarCache()

# NOTE(review): not referenced anywhere in this module; possibly imported
# elsewhere — confirm before removing.
_download_lock = asyncio.Lock()

__all__ = ("CachedAvatar", "avatar_cache")