Coverage for slidge/command/base.py: 94%

204 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1from abc import ABC 

2from dataclasses import dataclass, field 

3from enum import Enum 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Awaitable, 

8 Callable, 

9 Iterable, 

10 Optional, 

11 Sequence, 

12 Type, 

13 TypedDict, 

14 Union, 

15) 

16 

17from slixmpp import JID # type: ignore[attr-defined] 

18from slixmpp.exceptions import XMPPError 

19from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

20from slixmpp.plugins.xep_0004 import ( 

21 FormField as SlixFormField, # type: ignore[attr-defined] 

22) 

23from slixmpp.types import JidStr 

24 

25from ..core import config 

26from ..util.types import AnyBaseSession, FieldType 

27 

# URI prefix for the ad-hoc command nodes of slidge's core commands
# (not used in this chunk; presumably prepended to ``Command.NODE`` values
# elsewhere — TODO confirm).
NODE_PREFIX: str = "https://slidge.im/command/core/"

29 

30if TYPE_CHECKING: 

31 from ..core.gateway import BaseGateway 

32 from ..core.session import BaseSession 

33 from .categories import CommandCategory 

34 

35 

# A command handler: receives the session and the requesting JID and returns a
# command response, either directly or through an awaitable.
HandlerType = Union[
    Callable[[AnyBaseSession, JID], "CommandResponseType"],
    Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]],
]

# Submitted form values, keyed by the ``var`` of each field.
FormValues = dict[str, Union[str, JID, bool]]


# Handler of a form submission: receives the submitted values, the session and
# the requesting JID, and asynchronously produces a command response.
FormHandlerType = Callable[
    [FormValues, AnyBaseSession, JID],
    Awaitable["CommandResponseType"],
]

# Handler of a confirmation dialog: receives the session (possibly None) and
# the requesting JID, and asynchronously produces a command response.
ConfirmationHandlerType = Callable[
    [Optional[AnyBaseSession], JID],
    Awaitable["CommandResponseType"],
]

52 

53 

@dataclass
class TableResult:
    """
    Structured, tabular data returned as the result of a command
    """

    fields: Sequence["FormField"]
    """
    The 'column names' of the table.
    """
    items: Sequence[dict[str, Union[str, JID]]]
    """
    The rows of the table. Each row is a dict keyed by the ``var`` attribute
    of the fields.
    """
    description: str
    """
    A description of the content of the table.
    """

    # Not read by get_xml(); presumably consumed elsewhere to tell whether the
    # JID cells point to group chats — TODO confirm.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Build a slixmpp "result" form, with a <reported> header, holding this table.

        Every cell value is converted with ``str()`` before being added.

        :return: some XML
        """
        table = SlixForm()  # type: ignore[no-untyped-call]
        table["type"] = "result"
        table["title"] = self.description
        for column in self.fields:
            table.add_reported(  # type: ignore[no-untyped-call]
                column.var, label=column.label, type=column.type
            )
        for row in self.items:
            stringified_row = {key: str(cell) for key, cell in row.items()}
            table.add_item(stringified_row)  # type: ignore[no-untyped-call]
        return table

90 

91 

@dataclass
class SearchResult(TableResult):
    """
    Contact search results (looking up contacts via Jabber Search).

    This is what :meth:`BaseSession.search` returns; it is a table whose
    description defaults to a generic title.
    """

    description: str = "Contact search results"

101 

102 

@dataclass
class Confirmation:
    """
    A confirmation 'dialog': a prompt the user accepts or rejects through a
    single boolean field
    """

    prompt: str
    """
    The text shown to the user who triggered the command
    """
    handler: ConfirmationHandlerType
    """
    An async function that should return a ResponseType
    """
    success: Optional[str] = None
    """
    Text in case of success, used if the handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    positional arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form asking for confirmation.

        The form is titled with :attr:`prompt` and contains one boolean field
        named ``confirm``, pre-set to ``"true"``.

        :return: some xml
        """
        confirm_xml = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        ).get_xml()
        dialog = SlixForm()  # type: ignore[no-untyped-call]
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm_xml)
        return dialog

145 

146 

@dataclass
class Form:
    """
    A form, to request user input
    """

    # Title of the form, shown to the user
    title: str
    # Free-text instructions displayed along with the form
    instructions: str
    # The fields the user is asked to fill in
    fields: Sequence["FormField"]
    # Async callable receiving the submitted values, the session and the JID
    handler: FormHandlerType
    # Extra positional / keyword arguments passed to the handler
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, Union[list[str], list[JID], str, JID, bool, None]]:
        """
        Parse and validate a form submission.

        Each field of this form validates (and possibly converts) its own
        submitted value; fields absent from the submission are validated
        against ``None``.

        :param slix_form: the xml received as the submission of a form
        :return: A dict keyed by ``field.var``; values are strings, JIDs
            (``jid-single``), bools (``boolean``), lists (``*-multi``) or None
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" presenting this form to the user.

        :return: some XML
        """
        xml_form = SlixForm()  # type: ignore[no-untyped-call]
        xml_form["type"] = "form"
        xml_form["title"] = self.title
        xml_form["instructions"] = self.instructions
        for form_field in self.fields:
            field_xml = form_field.get_xml()
            xml_form.append(field_xml)
        return xml_form

189 

190 

class CommandAccess(int, Enum):
    """
    Defines who can access a given Command.

    In every case the requesting JID must first be allowed to use the gateway
    at all (or be an admin).
    """

    # Only JIDs listed as gateway admins
    ADMIN_ONLY = 0
    # Users registered to the gateway (a session exists)
    USER = 1
    # Registered users whose session is logged in to the legacy service
    USER_LOGGED = 2
    # Registered users whose session is NOT logged in to the legacy service
    USER_NON_LOGGED = 3
    # Only entities that are not registered (no session)
    NON_USER = 4
    # No additional restriction
    ANY = 5

202 

203 

class Option(TypedDict):
    """
    One choice offered by a ``FormField`` of type ``list-*``
    """

    # Human-readable text shown to the user for this choice
    label: str
    # Machine value that submitted form data is compared against
    value: str

211 

212 

# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.

    Also used to describe the fields of ad-hoc command forms and the columns of
    tabular command results.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: Optional[str] = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces ``type`` to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: Optional[list[Option]] = None
    """
    Choices offered to the user. Mandatory for ``list-single`` and ``list-multi``
    fields. For ``jid-multi`` fields, submitted JIDs are restricted to these
    options only when some are set.
    """

    image_url: Optional[str] = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # ``private`` takes precedence over whatever ``type`` was given.
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        # A list field without options is a programming error, not a user
        # error, hence RuntimeError rather than XMPPError.
        if self.options is None:
            raise RuntimeError(
                f"Field '{self.var}' has no options to validate against"
            )
        return [x["value"] for x in self.options]

    def validate(
        self, value: Optional[Union[str, list[str]]]
    ) -> Union[list[str], list[JID], str, JID, bool, None]:
        """
        Raise an appropriate XMPPError if a given value is NOT valid for this
        field, otherwise return it, possibly converted.

        :param value: The value to test. It may also already be a bool (for
            ``boolean`` fields), which is accepted as-is.
        :return: ``None`` if no value was given; a :class:`JID` for
            ``jid-single``; a list of JIDs for ``jid-multi``; a bool for
            ``boolean``; otherwise the value unchanged.
        :raises XMPPError: "not-acceptable" when the value has the wrong
            cardinality, is missing for a required field, is not a valid JID,
            or is not one of the field's options.
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi"):
            # An absent multi-value field means "nothing selected".
            if not value:
                value = []
            if isinstance(value, list):
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        # Lists were handled above; this only narrows the type.
        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            # Fall back to ``var`` so the message never reads "'None'".
            raise XMPPError(
                "not-acceptable", f"Missing field: '{self.label or self.var}'"
            )

        if value is None:
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'")

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> Union[list[str], list[JID]]:
        # list-multi fields must pick from their options. jid-multi fields
        # usually carry no options (XEP-0004 only defines <option/> for list
        # types), so they are only restricted when options were declared.
        if self.type == "list-multi" or self.options is not None:
            acceptable = self.__acceptable_options()
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value
        jids: list[JID] = []
        for v in value:
            try:
                jids.append(JID(v))
            except ValueError:
                # Same error as jid-single instead of leaking a raw ValueError
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'")
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

321 

322 

# Everything a command, or one of its form/confirmation handlers, may return:
# a table, a confirmation dialog, a form, a plain text, or nothing.
CommandResponseType = Optional[Union[TableResult, Confirmation, Form, str]]

324 

325 

class Command(ABC):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Merely subclassing it registers the subclass in :attr:`subclasses`.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Optional[Union[str, "CommandCategory"]] = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    # Every subclass of Command, appended by __init_subclass__.
    subclasses = list[Type["Command"]]()

    def __init__(self, xmpp: "BaseGateway"):
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Cooperate with other bases' __init_subclass__ hooks (PEP 487) instead
        # of silently swallowing class keyword arguments.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self, session: Optional["BaseSession[Any, Any]"], ifrom: JID, *args: str
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: "feature-not-implemented" unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        # No stored user for this JID means there is no session either.
        user = self.xmpp.store.users.get(jid)
        if user is None:
            return None

        return self.xmpp.get_session_from_user(user)

    def __can_use_command(self, jid: JID) -> bool:
        # Admins are allowed even when their bare JID does not match the
        # gateway's JID validator.
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j)) or j in config.ADMINS

    def raise_if_not_authorized(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :return: session of JID if it exists
        :raises XMPPError: "bad-request" if the JID may not use this gateway at
            all, or if a registered user calls a non-user-only command;
            "not-authorized" if a non-admin calls an admin-only command;
            "forbidden" if the registration/login state does not match
            :attr:`ACCESS`.
        """
        session = self._get_session(jid)
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session

440 

441 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    Only the bare part of the JID is looked up in ``config.ADMINS``, so every
    resource of an admin account qualifies.
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS