Coverage for slidge / command / base.py: 95%

236 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1from abc import ABC, abstractmethod 

2from collections.abc import Awaitable, Callable, Iterable, Sequence 

3from dataclasses import dataclass, field 

4from enum import Enum 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 ClassVar, 

9 Generic, 

10 TypedDict, 

11 TypeVar, 

12 Union, 

13) 

14 

15from slixmpp import JID 

16from slixmpp.exceptions import XMPPError 

17from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

18from slixmpp.plugins.xep_0004.stanza.field import FormField as SlixFormField 

19from slixmpp.types import JidStr 

20 

21from slidge.contact import LegacyContact 

22from slidge.group import LegacyMUC 

23 

24from ..core import config 

25from ..util.types import ( 

26 AnyContact, 

27 AnyMUC, 

28 AnySession, 

29 FieldType, 

30 LegacyContactType, 

31 LegacyMUCType, 

32 SessionType, 

33) 

34 

35NODE_PREFIX = "https://slidge.im/command/core/" 

36 

37if TYPE_CHECKING: 

38 from ..core.gateway import BaseGateway 

39 from .categories import CommandCategory 

40 

41 

42HandlerType = ( 

43 Callable[[AnySession, JID], "CommandResponseType"] 

44 | Callable[[AnySession, JID], Awaitable["CommandResponseType"]] 

45) 

46 

47FormValues = dict[str, str | JID | bool] 

48 

49 

@dataclass
class TableResult:
    """
    Tabular data returned as the outcome of a command
    """

    fields: Sequence["FormField"]
    """
    The column definitions, ie, the header of the table.
    """
    items: Sequence[dict[str, str | JID]]
    """
    The table rows: one dict per row, keyed by the ``var`` attribute of the
    entries of ``fields``.
    """
    description: str
    """
    Human-readable summary of what the table contains.
    """

    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" (with its <reported> header) carrying the
        table

        :return: a form of type "result", titled with ``description``
        """
        xml = SlixForm()
        xml["type"] = "result"
        xml["title"] = self.description
        # one <reported> entry per column
        for column in self.fields:
            xml.add_reported(column.var, label=column.label, type=column.type)
        # one <item> per row; every cell is stringified first
        for row in self.items:
            cells = {}
            for key, cell in row.items():
                cells[key] = str(cell)
            xml.add_item(cells)
        return xml

86 

87 

@dataclass
class SearchResult(TableResult):
    """
    Outcome of the search command (looking up contacts via Jabber Search)

    This is what :meth:`BaseSession.search` returns.
    """

    # same table as the parent class, only with a ready-made title
    description: str = "Contact search results"

97 

98 

@dataclass
class Confirmation:
    """
    A 'dialog' asking the user to confirm something
    """

    prompt: str
    """
    The text shown to the user who triggered the command
    """
    handler: Any
    """
    An async function that should return a ResponseType
    """
    success: str | None = None
    """
    Text used in case of success, when the handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    positional arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form for this confirmation

        :return: a form titled with the prompt and holding a single boolean
            "confirm" field, pre-set to "true"
        """
        confirm_field = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        dialog = SlixForm()
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm_field.get_xml())
        return dialog

141 

142 

@dataclass
class ConfirmationSession(Confirmation, Generic[SessionType]):
    """
    A :class:`Confirmation` whose handler takes an optional session and a JID
    """

    # Only the type of ``handler`` is narrowed compared to the parent class.
    handler: Callable[
        [SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]

149 

150 

151RecipientType = TypeVar( 

152 "RecipientType", bound=LegacyContact[Any] | LegacyMUC[Any, Any, Any, Any] 

153) 

154 

155 

@dataclass
class ConfirmationRecipient(Confirmation, Generic[RecipientType]):
    """
    A :class:`Confirmation` whose handler takes the recipient as sole argument
    """

    # Only the type of ``handler`` is narrowed compared to the parent class.
    handler: Callable[
        [RecipientType],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]

162 

163 

@dataclass
class Form:
    """
    A form, used to ask the user for some input
    """

    title: str
    instructions: str
    fields: Sequence["FormField"]
    handler: Any
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Parse and validate a form submission

        :param slix_form: the xml received as the submission of a form
        :return: A dict with one entry per field of this form, where
            keys=field.var and values are whatever ``FormField.validate``
            returned for the submitted value (``None`` is passed to it when
            the field is absent from the submission)
        """
        submitted: dict[str, str] = slix_form.get_values()
        return {
            form_field.var: form_field.validate(submitted.get(form_field.var))
            for form_field in self.fields
        }

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form"

        :return: a form of type "form" carrying the title, the instructions
            and the XML of every field
        """
        xml = SlixForm()
        xml["type"] = "form"
        xml["title"] = self.title
        xml["instructions"] = self.instructions
        for form_field in self.fields:
            xml.append(form_field.get_xml())
        return xml

207 

208 

# Decorated like FormRecipient, ConfirmationSession and ConfirmationRecipient,
# so that the narrowed ``handler`` annotation is the one seen by the generated
# ``__init__``. The field list and its order are unchanged.
@dataclass
class FormSession(Form, Generic[SessionType]):
    """
    A :class:`Form` whose handler takes the submitted values, an optional
    session and a JID
    """

    handler: Callable[
        [FormValues, SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]

214 

215 

@dataclass
class FormRecipient(Form, Generic[RecipientType]):
    """
    A :class:`Form` whose handler takes the recipient and the submitted values
    """

    # Only the type of ``handler`` is narrowed compared to the parent class.
    handler: Callable[
        [RecipientType, FormValues],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]

222 

223 

class CommandAccess(int, Enum):
    """
    Describes which entities are allowed to run a given Command
    """

    # gateway administrators only
    ADMIN_ONLY = 0
    # entities that have a session
    USER = 1
    # entities that have a session which is logged
    USER_LOGGED = 2
    # entities that have a session which is not logged
    USER_NON_LOGGED = 3
    # entities that do not have a session
    NON_USER = 4
    # no session-related restriction
    ANY = 5

235 

236 

class Option(TypedDict):
    """
    A choice offered by a ``FormField`` of type ``list-*``
    """

    # text associated with the choice
    label: str
    # value checked against submissions by ``FormField.validate``
    value: str

244 

245 

246# TODO: support forms validation XEP-0122 

@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """Acceptable choices, checked by :meth:`validate` for ``list-*`` fields"""

    image_url: str | None = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # "private" takes precedence over whatever type was passed
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        """
        List the ``value`` of every declared option

        :raise RuntimeError: if no options were declared for this field
        """
        if self.options is None:
            raise RuntimeError(f"Field '{self.var}' does not declare any option")
        return [x["value"] for x in self.options]

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Raise appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test
        :return: The same value, OR a JID if ``self.type=jid-single``, a list
            of JIDs if ``self.type=jid-multi``, a bool if ``self.type=boolean``
        :raise XMPPError: ("not-acceptable") if the value is rejected
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi", "text-multi"):
            # a missing or empty submission is an empty list of values
            if not value:
                value = []
            if not isinstance(value, list):
                raise XMPPError("not-acceptable", "Multiple values was expected")
            if self.type == "text-multi":
                return value
            return self.__validate_list_multi(value)

        # lists were all handled above, this only narrows the type
        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            raise XMPPError("not-acceptable", f"Missing field: '{self.label}'")

        if value is None:
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError as e:
                raise XMPPError(
                    "not-acceptable", f"Not a valid JID: '{value}'"
                ) from e

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        """
        Validate the values of a ``list-multi`` or ``jid-multi`` field

        :param value: The submitted values
        :return: The same list for ``list-multi``, a list of JIDs otherwise
        :raise XMPPError: ("not-acceptable") if a value is not one of the
            declared options, or is not a valid JID for ``jid-multi``
        """
        # A jid-multi field that declares no options is free-form: only the
        # JID syntax is checked. list-multi always needs declared options.
        free_form = self.type == "jid-multi" and self.options is None
        if value and not free_form:
            # built once, and only if there is something to check
            acceptable = self.__acceptable_options()
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value
        jids = []
        for v in value:
            # same error as for jid-single, instead of a raw ValueError
            try:
                jids.append(JID(v))
            except ValueError as e:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'") from e
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

356 

357 

358CommandResponseType = TableResult | Confirmation | Form | str | None 

359 

360CommandResponseSessionType = ( 

361 TableResult 

362 | ConfirmationSession[SessionType] 

363 | FormSession[SessionType] 

364 | str 

365 | None 

366) 

367 

368CommandResponseRecipientType = ( 

369 TableResult 

370 | ConfirmationRecipient[RecipientType] 

371 | FormRecipient[RecipientType] 

372 | str 

373 | None 

374) 

375 

376 

class _CommandMixin(ABC):
    """
    Class-level attributes describing a command

    They all default to ``NotImplemented``.
    """

    NAME: str = NotImplemented
    """
    Human-friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Longer explanation of what the command does
    """
    NODE: str = NotImplemented
    """
    Node name, used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway in a message to trigger the command
    """

394 

395 

class Command(_CommandMixin, Generic[SessionType]):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    subclasses: ClassVar[list[type["Command[SessionType]"]]] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first: typing.Generic relies on its own
        # __init_subclass__ to set up each subclass.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self,
        session: SessionType | None,
        ifrom: JID,
        *args: str,
    ) -> CommandResponseSessionType[SessionType]:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raise XMPPError: ("feature-not-implemented") unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> SessionType | None:
        """
        Look up the session of a JID through the gateway

        :param jid: the JID to look up
        :return: whatever ``self.xmpp.get_session_from_jid`` returns
        """
        return self.xmpp.get_session_from_jid(jid)  # type:ignore

    def __can_use_command(self, jid: JID) -> bool:
        """
        Tell whether the bare JID matches the gateway's JID validator or is
        listed in ``config.ADMINS``
        """
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j) or j in config.ADMINS)

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: SessionType | None = None,
    ) -> SessionType | None:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if true, the session is looked up from ``jid``
            and the ``session`` argument is ignored
        :param session: the session to check against ``ACCESS``, only used
            when ``fetch_session`` is false

        :return: session of JID if it exists
        :raise XMPPError: if ``jid`` may not use the gateway, or if the
            ``ACCESS`` requirement of the command is not met
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        # CommandAccess.ANY matches none of the branches below
        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session

504 

505 

506T = TypeVar("T", bound=AnyContact | AnyMUC) 

507 

508 

class _RecipientCommand(_CommandMixin, Generic[T]):
    """
    Common base of the commands that are attached to a recipient
    """

    @staticmethod
    @abstractmethod
    async def run(
        recipient: T,
        *args: str,
    ) -> CommandResponseRecipientType[RecipientType]:
        """
        Entrypoint for a recipient-specific command.

        ``recipient`` is a :class:`LegacyContact` or :class:`LegacyMUC`
        instance, and ``*args`` are the extra args passed when using the
        chatbot.
        """
        raise NotImplementedError

522 

523 

class ContactCommand(_RecipientCommand[AnyContact], Generic[LegacyContactType]):
    """
    A command that will be available on a contact.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyContact

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first: typing.Generic relies on its own
        # __init_subclass__ to set up each subclass.
        super().__init_subclass__(**kwargs)
        # Subclassing is enough to register the command on the recipient
        # class, under both its ad-hoc node and its chat command.
        cls.recipient_cls.commands[cls.NODE] = cls  # type:ignore[assignment]
        cls.recipient_cls.commands_chat[cls.CHAT_COMMAND] = cls  # type:ignore[assignment]

541 

542 

class MUCCommand(_RecipientCommand[AnyMUC], Generic[LegacyMUCType]):
    """
    A command that will be available on a MUC.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyMUC

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first: typing.Generic relies on its own
        # __init_subclass__ to set up each subclass.
        super().__init_subclass__(**kwargs)
        # Subclassing is enough to register the command on the recipient
        # class, under both its ad-hoc node and its chat command.
        cls.recipient_cls.commands[cls.NODE] = cls
        cls.recipient_cls.commands_chat[cls.CHAT_COMMAND] = cls

560 

561 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether a JID belongs to a gateway administrator

    :param jid: the JID to test; only its bare part is considered
    :return: ``True`` if the bare JID is listed in ``config.ADMINS``
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS