Coverage for slidge / command / base.py: 95%

235 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 04:38 +0000

1from abc import ABC, abstractmethod 

2from collections.abc import Awaitable, Callable, Iterable, Sequence 

3from dataclasses import dataclass, field 

4from enum import Enum 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 ClassVar, 

9 Generic, 

10 TypedDict, 

11 TypeVar, 

12 Union, 

13) 

14 

15from slixmpp import JID 

16from slixmpp.exceptions import XMPPError 

17from slixmpp.plugins.xep_0004 import Form as SlixForm 

18from slixmpp.plugins.xep_0004.stanza.field import FormField as SlixFormField 

19from slixmpp.types import JidStr 

20 

21from slidge.contact import LegacyContact 

22from slidge.group import LegacyMUC 

23 

24from ..core import config 

25from ..util.types import ( 

26 AnyMUC, 

27 AnySession, 

28 FieldType, 

29 LegacyContactType, 

30 LegacyMUCType, 

31 SessionType, 

32) 

33 

# Prefix of slidge core command node URIs (not referenced in this part of the
# file; presumably used by the concrete commands -- TODO confirm).
NODE_PREFIX = "https://slidge.im/command/core/"

if TYPE_CHECKING:
    # Only required by the quoted annotations below, hence never imported at
    # runtime.
    from ..util.types import AnyGateway
    from .categories import CommandCategory


# A handler is called with a session and a JID; it yields a command response,
# either directly or through an awaitable.
HandlerType = (
    Callable[[AnySession, JID], "CommandResponseType"]
    | Callable[[AnySession, JID], Awaitable["CommandResponseType"]]
)

# Parsed form submission: field ``var`` -> converted value.
FormValues = dict[str, str | JID | bool | list[str] | list[JID]]

47 

48 

@dataclass
class TableResult:
    """
    Tabular data returned as the outcome of a command
    """

    fields: Sequence["FormField"]
    """
    The column headers of the table, one field per column.
    """
    items: Sequence[dict[str, str | JID]]
    """
    The table rows. A row is a dict mapping the ``var`` attribute of a field
    to the content of the matching cell.
    """
    description: str
    """
    Text describing what the table contains; used as the form title.
    """

    # Not read by this class. Presumably tells consumers that the JIDs in the
    # rows are group chats -- TODO confirm against the callers.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Render the table as a slixmpp "result" form with a <reported> header

        :return: some XML
        """
        xml = SlixForm()
        xml["type"] = "result"
        xml["title"] = self.description
        for column in self.fields:
            xml.add_reported(column.var, label=column.label, type=column.type)
        for row in self.items:
            # cells may be JIDs: everything is serialised as text
            cells = {}
            for var, cell in row.items():
                cells[var] = str(cell)
            xml.add_item(cells)
        return xml

85 

86 

@dataclass
class SearchResult(TableResult):
    """
    Table holding the outcome of a contact search (via Jabber Search)

    This is what :meth:`BaseSession.search` returns.
    """

    # Same field as in TableResult, only given a default value here.
    description: str = "Contact search results"

96 

97 

@dataclass
class Confirmation:
    """
    A yes/no 'dialog' asking the user to confirm an action
    """

    prompt: str
    """
    The question shown to the user who triggered the command
    """
    handler: Any
    """
    An async function that should return a ResponseType
    """
    success: str | None = None
    """
    Success message, used when the handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    positional arguments given to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments given to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form for this dialog: the prompt as title and a
        single boolean "confirm" field, pre-set to true

        :return: some xml
        """
        confirm = FormField(
            var="confirm", label="Confirm", type="boolean", value="true"
        )
        dialog = SlixForm()
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm.get_xml())
        return dialog

140 

141 

@dataclass
class ConfirmationSession(Confirmation, Generic[SessionType]):
    # Only narrows the annotation of ``handler``: here it is called with the
    # session (possibly None) and a JID, and its result must be awaited.
    handler: Callable[
        [SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]

148 

149 

# Type variable for the recipient-bound variants below: a contact or a MUC.
RecipientType = TypeVar(
    "RecipientType",
    bound=LegacyContact | LegacyMUC[Any],
)

151 

152 

@dataclass
class ConfirmationRecipient(Confirmation, Generic[RecipientType]):
    # Only narrows the annotation of ``handler``: here it is called with the
    # recipient (contact or MUC), and its result must be awaited.
    handler: Callable[
        [RecipientType],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]

159 

160 

@dataclass
class Form:
    """
    A form used to ask the user for some input
    """

    # Title and instructions of the rendered form (see get_xml).
    title: str
    instructions: str
    # The fields the user has to fill, in display order.
    fields: Sequence["FormField"]
    # Not called in this class; the typed subclasses describe its signature.
    handler: Any
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Parse a form submission

        Each field of this form validates and converts its own submitted
        value; a field absent from the submission is validated as ``None``.

        :param slix_form: the xml received as the submission of a form
        :return: A dict where keys=field.var and values are whatever
            :meth:`FormField.validate` returned for that field
        """
        submitted: dict[str, str] = slix_form.get_values()
        return {
            form_field.var: form_field.validate(submitted.get(form_field.var))
            for form_field in self.fields
        }

    def get_xml(self) -> SlixForm:
        """
        Render this form as a slixmpp "form"

        :return: some XML
        """
        xml = SlixForm()
        xml["type"] = "form"
        xml["title"] = self.title
        xml["instructions"] = self.instructions
        for form_field in self.fields:
            xml.append(form_field.get_xml())
        return xml

204 

205 

# Decorated like the other typed variants (ConfirmationSession,
# ConfirmationRecipient, FormRecipient), so that the dataclass machinery
# registers the narrowed ``handler`` annotation instead of the ``Any`` one
# inherited from Form. The generated __init__ keeps the same parameters, in
# the same order.
@dataclass
class FormSession(Form, Generic[SessionType]):
    # Only narrows the annotation of ``handler``: here it is called with the
    # parsed form values, the session (possibly None) and a JID, and its
    # result must be awaited.
    handler: Callable[
        [FormValues, SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]

211 

212 

@dataclass
class FormRecipient(Form, Generic[RecipientType]):
    # Only narrows the annotation of ``handler``: here it is called with the
    # recipient (contact or MUC) and the parsed form values, and its result
    # must be awaited.
    handler: Callable[
        [RecipientType, FormValues],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]

219 

220 

class CommandAccess(int, Enum):
    """
    Who is allowed to use a given Command

    The rules are enforced by ``Command.raise_if_not_authorized``.
    """

    # gateway admins only
    ADMIN_ONLY = 0
    # a session must exist for the JID
    USER = 1
    # a session must exist and be logged
    USER_LOGGED = 2
    # a session must exist and not be logged
    USER_NON_LOGGED = 3
    # no session must exist for the JID
    NON_USER = 4
    # no session-related restriction
    ANY = 5

232 

233 

class Option(TypedDict):
    """
    One selectable choice of a ``FormField`` of type ``list-*``
    """

    # passed to the form field together with ``value`` when rendering
    label: str
    # what a submission is compared against during validation
    value: str

241 

242 

# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """The acceptable values, checked for ``list-single`` and ``*-multi`` fields"""

    image_url: str | None = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # ``private`` takes precedence over whatever ``type`` was given
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        """
        Return the ``value`` of every option of this field.

        :raises RuntimeError: if no options were defined
        """
        if self.options is None:
            raise RuntimeError
        return [x["value"] for x in self.options]

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Raise appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test, ``None`` if nothing was submitted
        :return: The same value, OR a JID if ``self.type=jid-single``, a list
            of JIDs if ``self.type=jid-multi``, a bool if
            ``self.type=boolean``
        :raises XMPPError: ``not-acceptable`` if the value does not fit
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi", "text-multi"):
            # nothing submitted is the same as an empty selection
            if not value:
                value = []
            if not isinstance(value, list):
                raise XMPPError("not-acceptable", "Multiple values was expected")
            if self.type == "text-multi":
                return value
            return self.__validate_list_multi(value)

        # only single values are left at this point (narrows the type)
        assert isinstance(value, (str, bool, JID)) or value is None

        if value is None:
            if self.required:
                # fall back to the var: label is optional and "Missing
                # field: 'None'" would not help anyone
                raise XMPPError(
                    "not-acceptable", f"Missing field: '{self.label or self.var}'"
                )
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError as e:
                raise XMPPError(
                    "not-acceptable", f"Not a valid JID: '{value}'"
                ) from e

        if self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")
            return value

        if self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        """
        Validate the values submitted for a ``list-multi`` or ``jid-multi``
        field.

        :param value: The submitted values
        :return: The same list for ``list-multi``, a list of JIDs otherwise
        :raises XMPPError: ``not-acceptable`` if a value is not one of the
            options, or is not a valid JID
        :raises RuntimeError: if values were submitted for a ``list-multi``
            field that has no options
        """
        # A list-multi field always needs options. A jid-multi field may have
        # none, in which case any valid JID is accepted.
        check_options = self.type == "list-multi" or self.options is not None
        if value and check_options:
            # computed once, not once per submitted value
            acceptable = self.__acceptable_options()
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value

        jids = []
        for v in value:
            # same error as for jid-single, instead of a raw ValueError
            try:
                jids.append(JID(v))
            except ValueError as e:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'") from e
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)
        f["value"] = self.value
        if self.image_url:
            # the image is always advertised as PNG
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

353 

354 

# What a command (or a form/confirmation handler) may give back.
CommandResponseType = TableResult | Confirmation | Form | str | None

# Same, for commands whose forms and confirmations are bound to a session.
CommandResponseSessionType = (
    TableResult
    | ConfirmationSession[SessionType]
    | FormSession[SessionType]
    | str
    | None
)

# Same, for commands whose forms and confirmations are bound to a recipient.
CommandResponseRecipientType = (
    TableResult
    | ConfirmationRecipient[RecipientType]
    | FormRecipient[RecipientType]
    | str
    | None
)

372 

373 

class _CommandMixin(ABC):
    # Attributes shared by every kind of command. They all default to
    # NotImplemented.
    NAME: str = NotImplemented
    """
    Human-friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Detailed description of what the command does
    """
    NODE: str = NotImplemented
    """
    Node name of the command, for the ad-hoc commands interface
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text that triggers the command when sent to the gateway in a message
    """

391 

392 

class Command(_CommandMixin, Generic[SessionType]):
    """
    Abstract base class for gateway commands, available both through the
    chatbot and the ad-hoc interfaces
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who is allowed to use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy only used for the adhoc interface, not the chat command
    interface.
    """

    # Shared by the whole hierarchy: see __init_subclass__.
    subclasses: ClassVar[list[type["Command[SessionType]"]]] = []

    def __init__(self, xmpp: "AnyGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Subclassing is all it takes for slidge to pick up a command: every
        # subclass is recorded in the shared list.
        # NOTE(review): kwargs are not forwarded to super().__init_subclass__.
        cls.subclasses.append(cls)

    async def run(
        self,
        session: SessionType | None,
        ifrom: JID,
        *args: str,
    ) -> CommandResponseSessionType[SessionType]:
        """
        Entry point of the command

        This default implementation raises ``feature-not-implemented``.

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> SessionType | None:
        # delegated to the gateway
        return self.xmpp.get_session_from_jid(jid)  # type:ignore

    def __can_use_command(self, jid: JID) -> bool:
        # Allowed if the bare JID matches the gateway's validator or is
        # listed in the configured admins.
        bare = jid.bare
        if self.xmpp.jid_validator.match(bare):
            return True
        return bare in config.ADMINS

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: SessionType | None = None,
    ) -> SessionType | None:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if true, the session is looked up from ``jid``
            and the ``session`` argument is ignored
        :param session: the session to check, only used if ``fetch_session``
            is false

        :return: session of JID if it exists
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        # ACCESS has a single value, so at most one branch applies.
        access = self.ACCESS
        if access == CommandAccess.ADMIN_ONLY:
            if not is_admin(jid):
                raise XMPPError("not-authorized")
        elif access == CommandAccess.NON_USER:
            if session is not None:
                raise XMPPError(
                    "bad-request",
                    "This is only available for non-users. Unregister first.",
                )
        elif access == CommandAccess.USER:
            if session is None:
                raise XMPPError(
                    "forbidden",
                    "This is only available for users that are registered to this gateway",
                )
        elif access == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    "This is only available for users that are not logged to the legacy service",
                )
        elif access == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    "This is only available when you are logged in to the legacy service",
                )
        return session

499 

500 

# Recipient of a recipient-specific command: a contact or a MUC.
T = TypeVar(
    "T",
    bound="LegacyContact | AnyMUC",
)

502 

503 

class _RecipientCommand(_CommandMixin, Generic[T]):
    @staticmethod
    @abstractmethod
    async def run(recipient: T, *args: str) -> CommandResponseRecipientType[T]:
        """
        Entrypoint of a command that targets a specific recipient.

        :param recipient: the :class:`LegacyContact` or :class:`LegacyMUC`
            instance the command applies to
        :param args: extra args passed when using the chatbot
        """
        raise NotImplementedError

515 

516 

class ContactCommand(_RecipientCommand[LegacyContactType], Generic[LegacyContactType]):
    """
    A command that will be available on a contact.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyContact

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Subclassing registers the command on the recipient class, under
        # both its ad-hoc node and its chat command.
        target = cls.recipient_cls
        target.commands[cls.NODE] = cls  # type:ignore[assignment]
        target.commands_chat[cls.CHAT_COMMAND] = cls  # type:ignore[assignment]

534 

535 

class MUCCommand(_RecipientCommand[LegacyMUCType], Generic[LegacyMUCType]):
    """
    A command that will be available on a MUC.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyMUC

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Subclassing registers the command on the recipient class, under
        # both its ad-hoc node and its chat command.
        target = cls.recipient_cls
        target.commands[cls.NODE] = cls
        target.commands_chat[cls.CHAT_COMMAND] = cls

553 

554 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether a JID belongs to an admin of the gateway.

    :param jid: a JID, or its string form; only its bare part is considered
    :return: whether the bare JID is listed in ``config.ADMINS``
    """
    bare = JID(jid).bare
    return bare in config.ADMINS