Coverage for slidge/db/avatar.py: 92%

132 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import asyncio 

2import hashlib 

3import io 

4import logging 

5import uuid 

6from concurrent.futures import ThreadPoolExecutor 

7from dataclasses import dataclass 

8from http import HTTPStatus 

9from pathlib import Path 

10from typing import Optional 

11 

12import aiohttp 

13from multidict import CIMultiDictProxy 

14from PIL.Image import Image 

15from PIL.Image import open as open_image 

16from sqlalchemy import select 

17 

18from slidge.core import config 

19from slidge.db.models import Avatar 

20from slidge.db.store import AvatarStore 

21from slidge.util.types import URL, AvatarType 

22 

23 

@dataclass
class CachedAvatar:
    """
    An avatar image whose file lives in a cache directory.

    Mirrors the columns of a stored ``Avatar`` row, plus ``root``: the
    directory that ``filename`` is relative to.
    """

    pk: int  # primary key of the stored row
    filename: str  # file name, relative to ``root``
    hash: str  # content hash copied from the stored row
    height: int
    width: int
    root: Path  # directory that contains ``filename``
    etag: Optional[str] = None  # HTTP ETag recorded for a downloaded avatar, if any
    last_modified: Optional[str] = None  # HTTP Last-Modified recorded, if any

    @property
    def path(self):
        """Location of the image file: ``root / filename``."""
        return self.root / self.filename

    @property
    def data(self):
        """Raw bytes of the image file; read from disk on every access."""
        file_path = self.path
        return file_path.read_bytes()

    @staticmethod
    def from_store(stored: Avatar, root_dir: Path):
        """Build a CachedAvatar from a stored ``Avatar`` row whose file is in ``root_dir``."""
        kwargs = dict(
            pk=stored.id,
            filename=stored.filename,
            hash=stored.hash,
            height=stored.height,
            width=stored.width,
            etag=stored.etag,
            root=root_dir,
            last_modified=stored.last_modified,
        )
        return CachedAvatar(**kwargs)

55 

56 

class NotModified(Exception):
    """Raised when the server answers 304 Not Modified, i.e. the cached copy is still current."""

59 

60 

class AvatarCache:
    """
    Disk + database cache of avatar images, normalised to PNG files.

    ``__init__`` does not set ``dir``, ``http`` or ``store``. They must be
    assigned before use: ``dir`` through :meth:`set_dir`, and the other two
    presumably by the code that owns this instance (TODO confirm).
    """

    # Directory holding the cached PNG files (assigned by set_dir()).
    dir: Path
    # HTTP client session used to fetch URL avatars.
    http: aiohttp.ClientSession
    # Database index of cached avatars.
    store: AvatarStore

    def __init__(self):
        # Resampling (PIL thumbnail) runs in this pool so it does not block
        # the event loop.
        self._thread_pool = ThreadPoolExecutor(config.AVATAR_RESAMPLING_THREADS)

    def set_dir(self, path: Path):
        """
        Use ``path`` as the cache directory and prune stale store entries.

        The directory is created if missing (its parent must exist). Any
        stored avatar whose file is not present in ``path`` is deleted from
        the store.
        """
        self.dir = path
        self.dir.mkdir(exist_ok=True)
        with self.store.session():
            for stored in self.store.get_all():
                avatar = CachedAvatar.from_store(stored, root_dir=path)
                if avatar.path.exists():
                    continue
                log.warning(
                    "Removing avatar %s from store because %s does not exist",
                    avatar.hash,
                    avatar.path,
                )
                self.store.delete_by_pk(stored.id)

    def close(self):
        """Shut down the resampling thread pool, cancelling queued jobs."""
        self._thread_pool.shutdown(cancel_futures=True)

    def __get_http_headers(self, cached: Optional[CachedAvatar | Avatar]):
        """
        Build conditional-request headers from a previously cached avatar.

        Headers are only produced when the cached file still exists on disk,
        so a 304 answer can always be served from the local copy.
        """
        headers = {}
        if cached and (self.dir / cached.filename).exists():
            if last_modified := cached.last_modified:
                headers["If-Modified-Since"] = last_modified
            if etag := cached.etag:
                headers["If-None-Match"] = etag
        return headers

    async def __download(
        self,
        url: str,
        headers: dict[str, str],
    ) -> tuple[Image, CIMultiDictProxy[str]]:
        """
        GET ``url`` and decode the response body as an image.

        :return: the PIL image and the response headers
        :raises NotModified: if the server answers 304 Not Modified
        """
        async with self.http.get(url, headers=headers) as response:
            if response.status == HTTPStatus.NOT_MODIFIED:
                log.debug("Using avatar cache for %s", url)
                raise NotModified
            # NOTE(review): other non-2xx statuses are not checked; the body is
            # handed to PIL as-is, which presumably fails to decode an error page.
            return (
                open_image(io.BytesIO(await response.read())),
                response.headers,
            )

    async def __is_modified(self, url, headers) -> bool:
        """Send a conditional HEAD request; True unless the server answers 304."""
        async with self.http.head(url, headers=headers) as response:
            return response.status != HTTPStatus.NOT_MODIFIED

    async def url_modified(self, url: URL) -> bool:
        """
        Tell whether the avatar at ``url`` should be (re-)downloaded.

        Returns True when nothing is stored for ``url``. Otherwise returns the
        result of a conditional HEAD request.
        """
        # NOTE(review): unlike convert_or_get, this store lookup is not wrapped
        # in ``self.store.session()``; confirm callers provide a session.
        cached = self.store.get_by_url(url)
        if cached is None:
            return True
        headers = self.__get_http_headers(cached)
        return await self.__is_modified(url, headers)

    def get_by_pk(self, pk: int) -> CachedAvatar:
        """Return the cached avatar with primary key ``pk`` (asserts it exists)."""
        stored = self.store.get_by_pk(pk)
        assert stored is not None
        return CachedAvatar.from_store(stored, self.dir)

    @staticmethod
    async def _get_image(avatar: AvatarType) -> Image:
        """
        Open a non-URL avatar (raw bytes or a local file path) with PIL.

        :raises TypeError: for any other avatar type
        """
        if isinstance(avatar, bytes):
            return open_image(io.BytesIO(avatar))
        elif isinstance(avatar, Path):
            return open_image(avatar)
        raise TypeError("Avatar must be bytes or a Path", avatar)

    async def convert_or_get(self, avatar: AvatarType) -> CachedAvatar:
        """
        Return a cached PNG version of ``avatar``, storing it if needed.

        ``avatar`` may be a URL/str, raw bytes, or a local :class:`Path`.
        - A URL is downloaded with conditional headers from any previous
          download of the same URL; on 304 the stored entry is returned.
        - Images larger than ``config.AVATAR_SIZE`` in either dimension are
          downscaled in the thread pool.
        - If an avatar with the same SHA-1 of its PNG bytes is already
          stored, that entry is returned and no new file is written.
        """
        if isinstance(avatar, (URL, str)):
            with self.store.session():
                stored = self.store.get_by_url(avatar)
                try:
                    img, response_headers = await self.__download(
                        avatar, self.__get_http_headers(stored)
                    )
                except NotModified:
                    # Conditional headers are only sent when ``stored`` exists.
                    assert stored is not None
                    return CachedAvatar.from_store(stored, self.dir)
        else:
            img = await self._get_image(avatar)
            response_headers = None
        with self.store.session() as orm:
            resize = (size := config.AVATAR_SIZE) and any(x > size for x in img.size)
            if resize:
                await asyncio.get_running_loop().run_in_executor(
                    self._thread_pool, img.thumbnail, (size, size)
                )
                log.debug("Resampled image to %s", img.size)

            if (
                not resize
                and img.format == "PNG"
                and isinstance(avatar, (str, Path))
                and (path := Path(avatar))
                and path.exists()
            ):
                # Already a PNG file on disk that needs no resampling: keep its
                # exact bytes instead of re-encoding.
                img_bytes = path.read_bytes()
            else:
                with io.BytesIO() as f:
                    img.save(f, format="PNG")
                    img_bytes = f.getvalue()

            hash_ = hashlib.sha1(img_bytes).hexdigest()

            # Deduplicate by content *before* touching the disk. Writing first
            # left an orphaned <uuid>.png behind for every duplicate.
            stored = orm.execute(select(Avatar).where(Avatar.hash == hash_)).scalar()
            if stored is not None:
                return CachedAvatar.from_store(stored, self.dir)

            filename = str(uuid.uuid1()) + ".png"
            (self.dir / filename).write_bytes(img_bytes)

            stored = Avatar(
                filename=filename,
                hash=hash_,
                height=img.height,
                width=img.width,
                url=avatar if isinstance(avatar, (URL, str)) else None,
            )
            if response_headers:
                stored.etag = response_headers.get("etag")
                stored.last_modified = response_headers.get("last-modified")

            orm.add(stored)
            orm.commit()
            return CachedAvatar.from_store(stored, self.dir)

196 

197 

log = logging.getLogger(__name__)

# Module-wide cache instance; its dir/http/store attributes are not set by
# AvatarCache.__init__ and must be assigned before use.
avatar_cache = AvatarCache()

# NOTE(review): not referenced anywhere in this module; confirm no other
# module imports it before removing.
_download_lock = asyncio.Lock()

__all__ = (
    "CachedAvatar",
    "avatar_cache",
)