Coverage for slidge/db/avatar.py: 90%

165 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1import asyncio 

2import hashlib 

3import io 

4import logging 

5from concurrent.futures import ThreadPoolExecutor 

6from http import HTTPStatus 

7from pathlib import Path 

8from typing import Optional 

9 

10import aiohttp 

11from multidict import CIMultiDictProxy 

12from PIL.Image import Image 

13from PIL.Image import open as open_image 

14from sqlalchemy import select 

15 

16from ..core import config 

17from ..util.lock import NamedLockMixin 

18from ..util.types import Avatar as AvatarType 

19from .models import Avatar 

20from .store import AvatarStore 

21 

22 

class CachedAvatar:
    """Read-only view over a stored ``Avatar`` row and its PNG file on disk.

    Metadata (hash, dimensions, HTTP validators) comes straight from the
    database row; the image itself is the file ``<root_dir>/<hash>.png``.
    """

    def __init__(self, stored: Avatar, root_dir: Path) -> None:
        self.stored = stored
        self._root = root_dir

    @property
    def pk(self) -> int | None:
        """Primary key of the underlying row (typed as optional)."""
        return self.stored.id

    @property
    def hash(self) -> str:
        """Content hash of the avatar; also used as the file's stem."""
        return self.stored.hash

    @property
    def height(self) -> int:
        """Image height as recorded in the store."""
        return self.stored.height

    @property
    def width(self) -> int:
        """Image width as recorded in the store."""
        return self.stored.width

    @property
    def etag(self) -> str | None:
        """ETag header recorded for this avatar, if any."""
        return self.stored.etag

    @property
    def last_modified(self) -> str | None:
        """Last-Modified header recorded for this avatar, if any."""
        return self.stored.last_modified

    @property
    def path(self):
        """Location of this avatar's PNG file inside the cache directory."""
        stem_path = self._root / self.hash
        return stem_path.with_suffix(".png")

    @property
    def data(self):
        """PNG bytes of the avatar, read from disk on every access."""
        png_file = self.path
        return png_file.read_bytes()

59 

60 

class NotModified(Exception):
    """Signals that the server answered an avatar request with 304 Not Modified."""

63 

64 

class AvatarCache(NamedLockMixin):
    """Fetch, resample and persist avatars as PNG files keyed by content hash.

    Files are written as ``<dir>/<sha1>.png`` and indexed through ``Avatar``
    rows in ``store``. ``dir`` is assigned by :meth:`set_dir`; ``http`` and
    ``store`` are only declared here and are presumably assigned by whoever
    sets up the cache — TODO confirm.
    """

    dir: Path
    http: aiohttp.ClientSession
    store: AvatarStore

    def __init__(self) -> None:
        # Resampling (``img.thumbnail``) is submitted to this pool via
        # ``run_in_executor`` in :meth:`convert`, off the event loop thread.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)
        super().__init__()

    def get(self, stored: Avatar) -> CachedAvatar:
        """Wrap a stored row so its file can be located in :attr:`dir`."""
        return CachedAvatar(stored, self.dir)

    def set_dir(self, path: Path) -> None:
        """Use *path* as the cache directory and prune rows with missing files.

        The directory itself is created if absent (its parent must exist).
        Every stored avatar whose PNG file is not on disk is deleted from the
        store, with a warning.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        log.debug("Checking avatar files")
        with self.store.session(expire_on_commit=False) as orm:
            for stored in orm.query(Avatar).all():
                avatar = CachedAvatar(stored, path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                orm.delete(stored)
            orm.commit()

    def close(self) -> None:
        """Shut down the resampling thread pool, cancelling queued jobs."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(
        self, cached: Optional[CachedAvatar | Avatar] = None
    ) -> dict[str, str]:
        """Build conditional-request headers from a previously cached avatar.

        Validators are only sent when the cached PNG file still exists; a 304
        answer is useless if there is no local file to fall back on.
        """
        headers: dict[str, str] = {}
        if cached and (self.dir / cached.hash).with_suffix(".png").exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """GET *url* and decode the response body as an image.

        :return: the opened image and the response headers.
        :raises NotModified: if the server answers 304.
        :raises aiohttp.ClientResponseError: for any other error status.
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            response.raise_for_status()
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url, headers) -> bool:
        """HEAD *url*; any status other than 304 (errors included) means modified."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: str) -> bool:
        """Return whether the avatar at *url* should be (re-)downloaded.

        ``True`` when nothing is stored for this URL; otherwise the answer of a
        conditional HEAD request built from the stored validators.
        """
        with self.store.session() as orm:
            cached = orm.query(Avatar).filter_by(url=url).one_or_none()
            if cached is None:
                return True
            # Read the row's attributes while the session is still open; the
            # network round-trip below happens without holding the session.
            headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """Open the image carried by *avatar*, preferring in-memory data.

        :raises TypeError: if the avatar has neither ``data`` nor ``path``.
        """
        if avatar.data is not None:
            return open_image(io.BytesIO(avatar.data))
        elif avatar.path is not None:
            return open_image(avatar.path)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """Return the cached version of *avatar*, converting/storing it if needed.

        A stored row matching the avatar's unique ID short-circuits everything;
        URL avatars go through :meth:`__convert_url`; data/path avatars are
        opened locally and converted.
        """
        if avatar.unique_id is not None:
            with self.store.session() as orm:
                stored = (
                    orm.query(Avatar)
                    .filter_by(legacy_id=str(avatar.unique_id))
                    .one_or_none()
                )
            if stored is not None:
                return self.get(stored)

        if avatar.url is not None:
            return await self.__convert_url(avatar)

        # We opened this image, so we close it: for file-backed images PIL may
        # keep the underlying file handle open until ``close()``.
        with await self._get_image(avatar) as img:
            return await self.convert(avatar, img)

    async def __convert_url(self, avatar: AvatarType) -> CachedAvatar:
        """Return the stored avatar for *avatar.url*, downloading it if absent.

        Runs under ``self.lock(...)`` keyed by the unique ID (or the URL), which
        presumably serialises concurrent conversions of the same avatar.
        """
        assert avatar.url is not None
        async with self.lock(avatar.unique_id or avatar.url):
            with self.store.session() as orm:
                if avatar.unique_id is None:
                    stored = orm.query(Avatar).filter_by(url=avatar.url).one_or_none()
                else:
                    stored = (
                        orm.query(Avatar)
                        .filter_by(legacy_id=str(avatar.unique_id))
                        .one_or_none()
                    )
                if stored is not None:
                    return self.get(stored)

            # Nothing is cached at this point, so there are no validators to
            # send (freshness checks are done separately by url_modified()).
            try:
                img, response_headers = await self.__download(avatar.url, {})
            except NotModified as exc:
                # A 304 to an unconditional request is a server error; there is
                # no cached copy to fall back on. Raising explicitly replaces an
                # ``assert`` that would vanish under ``python -O``.
                raise RuntimeError(
                    f"Unexpected 304 Not Modified for avatar {avatar.url}"
                ) from exc

            with img:
                return await self.convert(avatar, img, response_headers)

    async def convert(
        self,
        avatar: AvatarType,
        img: Image,
        response_headers: CIMultiDictProxy[str] | None = None,
    ) -> CachedAvatar:
        """Resample *img* if needed, write it as PNG and record it in the store.

        The image is shrunk in the thread pool when either side exceeds
        ``config.AVATAR_SIZE`` (a falsy size disables resampling). The PNG is
        saved as ``<dir>/<sha1>.png``. If a row with the same hash already
        exists it is reused (updating its unique ID if it changed); otherwise a
        new row is stored, with ETag/Last-Modified taken from
        *response_headers* when given.
        """
        size = config.AVATAR_SIZE
        resize = bool(size) and any(x > size for x in img.size)
        if resize:
            await asyncio.get_running_loop().run_in_executor(
                self._thread_pool, img.thumbnail, (size, size)
            )
            log.debug("Resampled image to %s", img.size)

        if (
            not resize
            and img.format == "PNG"
            and avatar.path is not None
            and avatar.path.exists()
        ):
            # Untouched PNG already on disk: reuse its bytes as-is instead of
            # re-encoding it.
            img_bytes = avatar.path.read_bytes()
        else:
            with io.BytesIO() as f:
                img.save(f, format="PNG")
                img_bytes = f.getvalue()

        hash_ = hashlib.sha1(img_bytes).hexdigest()
        file_path = (self.dir / hash_).with_suffix(".png")
        if file_path.exists():
            log.warning("Overwriting %s", file_path)
        file_path.write_bytes(img_bytes)

        # Lookups and updates compare ``legacy_id`` against ``str(unique_id)``,
        # so store it as a string too.
        legacy_id = None if avatar.unique_id is None else str(avatar.unique_id)

        with self.store.session(expire_on_commit=False) as orm:
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()

            if stored is not None:
                if legacy_id is not None and legacy_id != stored.legacy_id:
                    log.warning(
                        "Updating the 'unique' ID of an avatar, was '%s', is now '%s'",
                        stored.legacy_id,
                        avatar.unique_id,
                    )
                    stored.legacy_id = legacy_id
                    orm.add(stored)
                    orm.commit()

                return self.get(stored)

        stored = Avatar(
            hash=hash_,
            height=img.height,
            width=img.width,
            url=avatar.url,
            legacy_id=legacy_id,
        )
        if response_headers:
            stored.etag = response_headers.get("etag")
            stored.last_modified = response_headers.get("last-modified")

        with self.store.session(expire_on_commit=False) as orm:
            orm.add(stored)
            orm.commit()
        return self.get(stored)

246 

247 

log = logging.getLogger(__name__)

# Module-wide cache instance. ``dir`` is filled in by ``set_dir()``; ``http``
# and ``store`` are only declared on the class, so they are presumably
# assigned by whoever initialises the cache — TODO confirm.
avatar_cache = AvatarCache()

# NOTE(review): not referenced elsewhere in this file; presumably kept for
# other modules that import it.
_download_lock = asyncio.Lock()

__all__ = ("CachedAvatar", "avatar_cache")