Coverage for ivatar/utils.py: 68%

188 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-21 23:07 +0000

1""" 

2Simple module providing reusable random_string function 

3""" 

4 

5import contextlib 

6import http.client 

7import random 

8import string 

9import logging 

10from io import BytesIO 

11from urllib.parse import urlparse 

12from urllib.error import URLError 

13from urllib.request import urlopen as urlopen_orig 

14 

15import requests 

16from PIL import Image, ImageDraw, ImageSequence 

17from ivatar.settings import DEBUG, URL_TIMEOUT 

18 

# Initialize logger
logger = logging.getLogger("ivatar")

# Optional Bluesky credentials: they stay None unless ivatar.settings defines
# them. Only ImportError ("cannot import name ...") is expected here, so that
# is all we suppress; anything else should surface.
BLUESKY_IDENTIFIER = None
BLUESKY_APP_PASSWORD = None
with contextlib.suppress(ImportError):
    from ivatar.settings import BLUESKY_IDENTIFIER, BLUESKY_APP_PASSWORD

26 

27 

def urlopen(url, timeout=URL_TIMEOUT):
    """
    Open *url* with urllib.request.urlopen, converting malformed-URL errors
    into URLError.

    When DEBUG is set, TLS hostname checking and certificate verification are
    disabled for the request; otherwise no custom SSL context is passed.

    Raises:
        URLError: when opening the URL raises http.client.InvalidURL,
            ValueError or UnicodeError (the original exception is chained).
        Any other exception from urllib.request.urlopen propagates unchanged.
    """
    ssl_context = None
    if DEBUG:
        import ssl

        # Development only: accept self-signed or mismatched certificates.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

    try:
        return urlopen_orig(url, timeout=timeout, context=ssl_context)
    except http.client.InvalidURL as exc:
        logger.warning(
            f"Invalid URL detected (possible injection attempt): {url!r} - {exc}"
        )
        # Re-raise as URLError so existing URLError handlers keep working
        raise URLError(f"Invalid URL: {exc}") from exc
    except (ValueError, UnicodeError) as exc:
        logger.warning(f"Malformed URL detected: {url!r} - {exc}")
        raise URLError(f"Malformed URL: {exc}") from exc

53 

54 

class Bluesky:
    """
    Handle Bluesky client access with persistent session management.

    The session returned by ``com.atproto.server.createSession`` is cached on
    the class, so every instance in the process reuses it until five minutes
    before the locally recorded refresh time (12 hours after login).

    NOTE(review): the cache is not keyed by identifier or service, so
    instances created with different credentials share whichever session was
    created last.
    """

    identifier = ""
    app_password = ""
    service = "https://bsky.social"
    session = None
    _shared_session = None  # Class-level shared session
    _session_expires_at = None  # Epoch seconds of the planned refresh

    def __init__(
        self,
        identifier: str = BLUESKY_IDENTIFIER,
        app_password: str = BLUESKY_APP_PASSWORD,
        service: str = "https://bsky.social",
    ):
        """
        :param identifier: identifier sent to createSession (e.g. a handle)
        :param app_password: password sent to createSession
        :param service: base URL of the Bluesky service to talk to
        """
        self.identifier = identifier
        self.app_password = app_password
        self.service = service

    def _is_session_valid(self) -> bool:
        """
        Return True if a shared session exists and is not yet within five
        minutes of its recorded expiry time.
        """
        if not self._shared_session or not self._session_expires_at:
            return False

        import time

        # Add 5 minute buffer before actual expiration
        return time.time() < (self._session_expires_at - 300)

    def login(self):
        """
        Login to Bluesky, reusing the shared session while it is valid.

        On return ``self.session`` holds the createSession response; its
        ``accessJwt`` is used as the bearer token for API requests.

        Raises:
            requests.exceptions.RequestException: if the login request fails
                (e.g. timeout, connection error, or an HTTP error status).
        """
        # Use shared session if available and valid
        if self._is_session_valid():
            self.session = self._shared_session
            logger.debug("Reusing existing Bluesky session")
            return

        import time

        logger.debug("Creating new Bluesky session")
        auth_response = requests.post(
            f"{self.service}/xrpc/com.atproto.server.createSession",
            json={"identifier": self.identifier, "password": self.app_password},
            timeout=URL_TIMEOUT,
        )
        auth_response.raise_for_status()
        self.session = auth_response.json()

        # Store the session on the class, not on the instance. Assigning via
        # ``self`` would create instance attributes that shadow the class
        # cache. Other instances could then not reuse the session, and
        # clear_shared_session() would not invalidate it.
        cls = type(self)
        cls._shared_session = self.session
        # Sessions typically expire in 24 hours, but we'll refresh every 12 hours
        cls._session_expires_at = time.time() + (12 * 60 * 60)

        logger.debug(
            "Created new Bluesky session, expires at: %s",
            time.strftime(
                "%Y-%m-%d %H:%M:%S", time.localtime(cls._session_expires_at)
            ),
        )

    @classmethod
    def clear_shared_session(cls):
        """
        Clear the shared session (useful for testing)
        """
        cls._shared_session = None
        cls._session_expires_at = None
        logger.debug("Cleared shared Bluesky session")

    def normalize_handle(self, handle: str) -> str:
        """
        Return *handle* without surrounding whitespace or leading '@' signs.

        Example: ' @alice.bsky.social ' -> 'alice.bsky.social'.
        """
        # Strip whitespace first so that " @handle" also loses its '@'.
        return handle.strip().lstrip("@").strip()

    def _fetch_profile(self, handle: str):
        """
        Issue one authenticated app.bsky.actor.getProfile request.

        Returns the decoded JSON body. Raises on any request failure,
        including HTTP error statuses (via raise_for_status).
        """
        profile_response = requests.get(
            f"{self.service}/xrpc/app.bsky.actor.getProfile",
            headers={"Authorization": f'Bearer {self.session["accessJwt"]}'},
            params={"actor": handle},
            timeout=URL_TIMEOUT,
        )
        profile_response.raise_for_status()
        return profile_response.json()

    def _make_profile_request(self, handle: str):
        """
        Make a profile request to the Bluesky API, retrying once on a 401.

        A 401 response is treated as an expired session. The shared session
        is cleared, a new login is performed, and the request is repeated.

        Returns the decoded profile JSON, or None if the request (or the
        re-login / retry) fails. Every failure is logged as a warning.
        """
        try:
            return self._fetch_profile(handle)
        except requests.exceptions.HTTPError as exc:
            status_code = getattr(exc.response, "status_code", None)
            if status_code != 401:
                logger.warning(f"Bluesky profile fetch failed with HTTP error: {exc}")
                return None
            # Session expired, try to login again (below, outside this handler)
            logger.warning("Bluesky session expired, re-authenticating")
        except Exception as exc:
            logger.warning(f"Bluesky profile fetch failed with error: {exc}")
            return None

        # Only reached after a 401. Failures while re-authenticating or
        # retrying are reported like first-attempt failures, rather than
        # escaping as exceptions.
        try:
            self.clear_shared_session()
            self.login()
            return self._fetch_profile(handle)
        except Exception as exc:
            logger.warning(
                f"Bluesky profile fetch failed after re-authentication: {exc}"
            )
            return None

    def get_profile(self, handle: str) -> "dict | None":
        """
        Return the profile JSON for *handle*, logging in first if needed.

        Returns None when the profile request fails. Errors raised by the
        initial login() propagate to the caller.
        """
        if not self.session or not self._is_session_valid():
            self.login()
        return self._make_profile_request(handle)

    def get_avatar(self, handle: str):
        """
        Get the avatar URL for a handle.

        Returns None if the profile could not be fetched or has no avatar.
        """
        profile = self.get_profile(handle)
        # "avatar" is optional in the getProfile response, so use .get().
        return profile.get("avatar") if profile else None

188 

189 

def random_string(length=10):
    """
    Return a string of *length* characters (default 10), each drawn from
    lowercase ASCII letters and digits using the OS random source.
    """
    alphabet = string.ascii_lowercase + string.digits
    rng = random.SystemRandom()
    picked = [rng.choice(alphabet) for _ in range(length)]
    return "".join(picked)

197 ) 

198 

199 

def generate_random_email():
    """
    Return a random e-mail address shaped like ``<local>@<domain>.<tld>``,
    where local and domain are 10 random characters and tld is 2
    (the same pattern as test_views.py uses).
    """
    local_part, domain, tld = random_string(), random_string(), random_string(2)
    return f"{local_part}@{domain}.{tld}"

208 

209 

def random_ip_address():
    """
    Return a random dotted-quad IPv4 address string; each of the four
    octets is an integer from 1 to 254 inclusive.
    """
    octets = (str(random.randint(1, 254)) for _ in range(4))
    return ".".join(octets)

215 

216 

def openid_variations(openid):
    """
    Return the various OpenID variations, ALWAYS in the same order:
     - http w/ trailing slash
     - http w/o trailing slash
     - https w/ trailing slash
     - https w/o trailing slash

    Only the leading scheme is rewritten. Any "http://" or "https://" that
    appears later in the URL (e.g. inside a query string) is left untouched.
    An input without an http(s) scheme is returned with and without a
    trailing slash, twice.
    """
    http_prefix = "http://"
    https_prefix = "https://"

    # Make the 'base' version: http w/ trailing slash
    if openid.startswith(https_prefix):
        openid = http_prefix + openid[len(https_prefix):]
    if not openid.endswith("/"):
        openid = f"{openid}/"

    # http w/o trailing slash
    var1 = openid[:-1]
    # https w/ trailing slash: swap only the leading scheme
    if openid.startswith(http_prefix):
        var2 = https_prefix + openid[len(http_prefix):]
    else:
        var2 = openid
    # https w/o trailing slash
    var3 = var2[:-1]
    return (openid, var1, var2, var3)

237 

238 

def _mm_channel(base, add):
    """Return ``base + add`` clamped to 0..255 as a two-digit lowercase hex string."""
    return f"{min(max(base + add, 0), 255):02x}"


def mm_ng(
    idhash, size=80, add_red=0, add_green=0, add_blue=0
):  # pylint: disable=too-many-locals
    """
    Return an MM (mystery man) image, based on a given hash
    add some red, green or blue, if specified

    :param idhash: hex string; its first two characters set the gray level
        of the background
    :param size: width and height of the square RGB image, in pixels
    :param add_red: amount added to the background's red channel
    :param add_green: amount added to the background's green channel
    :param add_blue: amount added to the background's blue channel
        (each resulting channel is clamped to 0..255)
    :return: a new PIL RGB image
    """
    # Make sure the lightest bg color we paint is e0, else
    # we do not see the MM any more
    if idhash[0] == "f":
        idhash = "e0"

    # How large is the circle?
    circle_size = size * 0.6

    # Coordinates for the circle
    start_x = int(size * 0.2)
    end_x = start_x + circle_size
    start_y = int(size * 0.05)
    end_y = start_y + circle_size

    # All channels start from the same value taken from the hash, which keeps
    # the background "gray-ish". Each channel is then shifted by its add_*
    # amount and clamped, so negative amounts cannot produce an invalid color.
    base = int(idhash[:2], 16)
    # Assemble the bg color "string" in web notation. Eg. '#d3d3d3'
    bg_color = "#" + "".join(
        _mm_channel(base, add) for add in (add_red, add_green, add_blue)
    )

    # Image
    image = Image.new("RGB", (size, size))
    draw = ImageDraw.Draw(image)

    # Draw background
    draw.rectangle(((0, 0), (size, size)), fill=bg_color)

    # Draw MMs head
    draw.ellipse((start_x, start_y, end_x, end_y), fill="white")

    # Draw MMs 'body'
    draw.polygon(
        (
            (start_x + circle_size / 2, size / 2.5),
            (size * 0.15, size),
            (size - size * 0.15, size),
        ),
        fill="white",
    )

    return image

312 

313 

def _url_matches_filter(url, parsed, ufilter):
    """Return True if *url* (already split into *parsed*) meets every condition in *ufilter*."""
    if "schemes" in ufilter and parsed.scheme not in ufilter["schemes"]:
        return False
    if "host_equals" in ufilter and parsed.netloc != ufilter["host_equals"]:
        return False
    if "host_suffix" in ufilter and not parsed.netloc.endswith(ufilter["host_suffix"]):
        return False
    if "path_prefix" in ufilter and not parsed.path.startswith(ufilter["path_prefix"]):
        return False
    if "url_prefix" in ufilter and not url.startswith(ufilter["url_prefix"]):
        return False
    return True


def is_trusted_url(url, url_filters):
    """
    Check if a URL is valid and considered a trusted URL.
    If the URL is malformed, returns False.

    ``url_filters`` is an iterable of dicts. The URL is trusted when it
    satisfies every condition present in at least one filter. Supported keys:
    ``schemes`` (collection of allowed schemes), ``host_equals``,
    ``host_suffix``, ``path_prefix`` and ``url_prefix``. Host conditions are
    compared against urlparse's ``netloc``, which includes any port or
    userinfo. A filter with no keys matches every URL.

    Based on: https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/API/events/UrlFilter
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. an unterminated IPv6
        # literal such as "https://[::1"); honour the documented contract.
        return False

    return any(_url_matches_filter(url, parsed, ufilter) for ufilter in url_filters)

357 

358 

def resize_animated_gif(input_pil: Image.Image, size: list) -> BytesIO:
    """
    Shrink every frame of an animated GIF to fit within *size* and return
    the re-encoded GIF as a BytesIO (its position is left at the end).

    Frames are resized with ``Image.thumbnail``, so the aspect ratio is kept.
    Each frame keeps its own display duration and disposal method. Previously
    the source's last frame supplied these for every frame, because the
    source is left on its last frame after iteration. Remaining save options
    (loop, background, ...) come from ``input_pil.info``.

    :param input_pil: an opened animated GIF; frames must expose
        ``disposal_method``
    :param size: maximum (width, height) of the output frames
    """
    frames = []
    durations = []
    disposals = []
    for frame in ImageSequence.Iterator(input_pil):
        # Read per-frame metadata while the source is positioned on this frame.
        durations.append(frame.info.get("duration"))
        disposals.append(frame.disposal_method)
        new_frame = frame.copy()
        new_frame.thumbnail(size)
        frames.append(new_frame)

    save_params = dict(input_pil.info)
    save_params["disposal"] = disposals
    # Only override the duration when every frame reported one; otherwise
    # fall back to whatever input_pil.info carries.
    if None not in durations:
        save_params["duration"] = durations

    output = BytesIO()
    frames[0].save(
        output,
        format="gif",
        save_all=True,
        optimize=False,
        append_images=frames[1:],
        **save_params,
    )
    return output