Coverage for slidge/command/base.py: 93%

205 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1from abc import ABC 

2from dataclasses import dataclass, field 

3from enum import Enum 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Awaitable, 

8 Callable, 

9 Iterable, 

10 Optional, 

11 Sequence, 

12 Type, 

13 TypedDict, 

14 Union, 

15) 

16 

17from slixmpp import JID 

18from slixmpp.exceptions import XMPPError 

19from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

20from slixmpp.plugins.xep_0004 import FormField as SlixFormField 

21from slixmpp.types import JidStr 

22 

23from ..core import config 

24from ..db.models import GatewayUser 

25from ..util.types import AnyBaseSession, FieldType 

26 

# Common URI prefix for slidge's core command nodes. It is not referenced
# elsewhere in this module; presumably it is prepended to ``Command.NODE``
# values by the core commands — confirm against callers.
NODE_PREFIX = "https://slidge.im/command/core/"

28 

29if TYPE_CHECKING: 

30 from ..core.gateway import BaseGateway 

31 from ..core.session import BaseSession 

32 from .categories import CommandCategory 

33 

34 

# A command callback taking (session, requesting JID). It may be either a
# plain function or a coroutine function; both yield a CommandResponseType.
HandlerType = Union[
    Callable[[AnyBaseSession, JID], "CommandResponseType"],
    Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]],
]

# Submitted form values keyed by field name — presumably what the form
# handler receives after Form.get_values() parsing; confirm against callers.
FormValues = dict[str, Union[str, JID, bool]]

# Coroutine function called with the submitted values, the session and the
# requesting JID once a Form has been filled in.
FormHandlerType = Callable[
    [FormValues, AnyBaseSession, JID],
    Awaitable["CommandResponseType"],
]

# Coroutine function called once a Confirmation is accepted. Unlike the other
# handlers, its session argument may be None.
ConfirmationHandlerType = Callable[
    [Union[AnyBaseSession, None], JID],
    Awaitable["CommandResponseType"],
]

51 

52 

@dataclass
class TableResult:
    """
    Tabular data returned by a command: column definitions plus rows of values
    """

    fields: Sequence["FormField"]
    """
    The 'columns names' of the table.
    """
    items: Sequence[dict[str, Union[str, JID]]]
    """
    The rows of the table. Each row is a dict where keys are the fields ``var``
    attribute.
    """
    description: str
    """
    A description of the content of the table.
    """

    # Not read anywhere in this class; presumably tells consumers that the JID
    # values in ``items`` designate group chats (MUCs) rather than contacts —
    # TODO confirm against callers.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Render the table as a slixmpp "result" form: a <reported> header with
        one column per entry of ``fields``, then one <item> per row of
        ``items``. Every cell value is converted with ``str()``.

        :return: some XML
        """
        result_form = SlixForm()  # type: ignore[no-untyped-call]
        result_form["type"] = "result"
        result_form["title"] = self.description
        for column in self.fields:
            result_form.add_reported(column.var, label=column.label, type=column.type)  # type: ignore[no-untyped-call]
        for row in self.items:
            stringified_row = {key: str(cell) for key, cell in row.items()}
            result_form.add_item(stringified_row)  # type: ignore[no-untyped-call]
        return result_form

89 

90 

@dataclass
class SearchResult(TableResult):
    """
    Table returned by the contact search command (contact lookup via Jabber
    Search).

    This is what :meth:`BaseSession.search` is expected to return; only the
    default ``description`` differs from :class:`TableResult`.
    """

    description: str = "Contact search results"

100 

101 

@dataclass
class Confirmation:
    """
    A confirmation 'dialog': a prompt the user accepts or declines
    """

    prompt: str
    """
    The text presented to the command triggering user
    """
    handler: ConfirmationHandlerType
    """
    An async function that should return a CommandResponseType
    """
    success: Optional[str] = None
    """
    Text in case of success, used if handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form for this dialog: titled with ``prompt`` and
        holding a single boolean ``confirm`` field that defaults to true.

        :return: some xml
        """
        confirm_checkbox = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        dialog = SlixForm()  # type: ignore[no-untyped-call]
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm_checkbox.get_xml())
        return dialog

144 

145 

@dataclass
class Form:
    """
    A form, to request input from the user
    """

    title: str
    instructions: str
    fields: Sequence["FormField"]
    handler: FormHandlerType
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, Union[list[str], list[JID], str, JID, bool, None]]:
        """
        Parse and validate a form submission

        Each of this form's ``fields`` validates its own submitted value via
        :meth:`FormField.validate`; any validation error propagates.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each ``field.var`` to its validated value: a
            string, a JID (``jid-single``), a bool (``boolean``), a list of
            strings or JIDs (``*-multi``), or None if the field was absent
        """
        submitted: dict[str, Union[str, list[str]]] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {
            form_field.var: form_field.validate(submitted.get(form_field.var))
            for form_field in self.fields
        }

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" with title, instructions and every field

        :return: some XML
        """
        xml_form = SlixForm()  # type: ignore[no-untyped-call]
        xml_form["type"] = "form"
        xml_form["title"] = self.title
        xml_form["instructions"] = self.instructions
        for form_field in self.fields:
            field_xml = form_field.get_xml()
            xml_form.append(field_xml)
        return xml_form

188 

189 

class CommandAccess(int, Enum):
    """
    Access level required to run a given Command.

    The matching checks are performed in ``Command.raise_if_not_authorized``.
    """

    # only JIDs whose bare part is listed in ``config.ADMINS``
    ADMIN_ONLY = 0
    # any JID that has a gateway session (registered), logged in or not
    USER = 1
    # registered users whose session is logged in to the legacy service
    USER_LOGGED = 2
    # registered users whose session is NOT logged in to the legacy service
    USER_NON_LOGGED = 3
    # JIDs that have no gateway session (not registered)
    NON_USER = 4
    # no restriction beyond the gateway-wide JID allow-list
    ANY = 5

201 

202 

class Option(TypedDict):
    """
    One selectable choice of a ``FormField`` whose type is ``list-*``
    """

    # human-readable text displayed to the user
    label: str
    # the string submitted when this choice is picked
    value: str

210 

211 

# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: Optional[str] = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces ``type`` to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: Optional[list[Option]] = None
    """
    Allowed choices for ``list-*`` fields. A submitted value is accepted only
    if it equals the ``value`` of one of them. Optional for ``jid-multi``.
    """

    image_url: Optional[str] = None
    """An image associated to this field, eg, a QR code (advertised as ``image/png``)"""

    def __post_init__(self) -> None:
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        """
        Return the ``value`` of every option.

        :raises RuntimeError: if the field has no ``options``
        """
        if self.options is None:
            raise RuntimeError
        return [x["value"] for x in self.options]

    def validate(
        self, value: Optional[Union[str, list[str]]]
    ) -> Union[list[str], list[JID], str, JID, bool, None]:
        """
        Raise an appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test, as submitted. ``None`` means the field
            was absent from the submission.
        :return: The same value, OR a JID if ``self.type=jid-single``, OR a bool
            if ``self.type=boolean``, OR a list of strings (``list-multi``) / JIDs
            (``jid-multi``) with empty entries dropped.
        :raises XMPPError: ``not-acceptable`` if the value has the wrong
            cardinality, is missing while required, is not a valid JID, or is
            not one of the allowed options.
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi"):
            if not value:
                value = []
            if isinstance(value, list):
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        # type narrowing only: lists were handled above
        assert isinstance(value, (str, bool, JID)) or value is None

        if value is None:
            if self.required:
                # label is optional; fall back to the field's var so the user
                # never sees "Missing field: 'None'"
                raise XMPPError(
                    "not-acceptable", f"Missing field: '{self.label or self.var}'"
                )
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError as exc:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'") from exc

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> Union[list[str], list[JID]]:
        """
        Validate the submitted values of a ``list-multi`` or ``jid-multi`` field.

        ``list-multi`` values must match one of ``options``. For ``jid-multi``,
        options are only enforced when some were defined: such fields usually
        have none. Without this exemption, any non-empty submission raised
        RuntimeError.
        """
        # COMPAT: dropping empty entries is a workaround for https://codeberg.org/slidge/slidge/issues/43
        # It should be reverted once the bug is fixed upstream, cf https://soprani.ca/todo/390
        non_empty = [v for v in value if v]

        # Only look options up when there is something to check, so an empty
        # submission never requires options to be defined.
        if non_empty and (self.type == "list-multi" or self.options is not None):
            acceptable = self.__acceptable_options()  # computed once, not per value
            for v in non_empty:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")

        if self.type == "list-multi":
            return non_empty

        jids: list[JID] = []
        for v in non_empty:
            try:
                jids.append(JID(v))
            except ValueError as exc:
                # same error as for jid-single instead of a bare ValueError
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'") from exc
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            # the media element always advertises the image as PNG
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

324 

325 

# Anything a command (or one of its handlers) may produce: a table, a
# confirmation dialog, a form to fill in, plain text, or nothing at all.
# Optional[Union[...]] flattens to the same five-member union.
CommandResponseType = Optional[Union[TableResult, Confirmation, Form, str]]

327 

328 

class Command(ABC):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Merely subclassing it registers the subclass in ``Command.subclasses``.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Optional[Union[str, "CommandCategory"]] = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    # Registry filled by __init_subclass__. Unless a subclass defines its own
    # ``subclasses`` attribute, ``cls.subclasses`` resolves to this single list,
    # so subclasses at every depth end up here.
    subclasses = list[Type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Chain to the parent hook so cooperative __init_subclass__
        # implementations (e.g. mixins) still run and unexpected class keyword
        # arguments are reported instead of silently swallowed.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self, session: Optional["BaseSession[Any, Any]"], ifrom: JID, *args: str
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: ``feature-not-implemented`` unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        """
        Whether the bare ``jid`` may use the gateway at all: it must match the
        gateway's ``jid_validator`` (presumably a compiled regex — confirm) or
        be listed in ``config.ADMINS``.
        """
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j)) or j in config.ADMINS

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: Optional["BaseSession[Any, Any]"] = None,
    ) -> Optional["BaseSession[Any, Any]"]:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look the session up from ``jid`` and
            ignore the ``session`` argument
        :param session: session to check against ``ACCESS`` when
            ``fetch_session`` is False
        :raises XMPPError: ``bad-request`` if the JID may not use the gateway
            at all, or if ``ACCESS`` is ``NON_USER`` and a session exists;
            ``not-authorized`` if ``ACCESS`` is ``ADMIN_ONLY`` and the JID is
            not an admin; ``forbidden`` if the session does not satisfy
            ``USER``, ``USER_LOGGED`` or ``USER_NON_LOGGED``.

        :return: session of JID if it exists
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        # CommandAccess.ANY: no check beyond __can_use_command above
        return session

448 

449 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    Only the bare part of the JID (resource stripped) is looked up in
    ``config.ADMINS``.
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS