Coverage for slidge / command / base.py: 95%
236 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1from abc import ABC, abstractmethod
2from collections.abc import Awaitable, Callable, Iterable, Sequence
3from dataclasses import dataclass, field
4from enum import Enum
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 ClassVar,
9 Generic,
10 TypedDict,
11 TypeVar,
12 Union,
13)
15from slixmpp import JID
16from slixmpp.exceptions import XMPPError
17from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
18from slixmpp.plugins.xep_0004.stanza.field import FormField as SlixFormField
19from slixmpp.types import JidStr
21from slidge.contact import LegacyContact
22from slidge.group import LegacyMUC
24from ..core import config
25from ..util.types import (
26 AnyContact,
27 AnyMUC,
28 AnySession,
29 FieldType,
30 LegacyContactType,
31 LegacyMUCType,
32 SessionType,
33)
# NOTE(review): not referenced in this chunk; presumably the prefix of the
# core ad-hoc command node names — confirm against the command registry.
NODE_PREFIX = "https://slidge.im/command/core/"

if TYPE_CHECKING:
    # Only needed for annotations; importing at runtime would be circular.
    from ..core.gateway import BaseGateway
    from .categories import CommandCategory

# A command handler takes the caller's session and JID and produces a
# response, either synchronously or as an awaitable.
HandlerType = (
    Callable[[AnySession, JID], "CommandResponseType"]
    | Callable[[AnySession, JID], Awaitable["CommandResponseType"]]
)

# Values handed to form handlers, keyed by field ``var``.
FormValues = dict[str, str | JID | bool]
@dataclass
class TableResult:
    """
    Tabular data returned as the result of a command
    """

    fields: Sequence["FormField"]
    """
    The table columns; each ``FormField`` describes one column.
    """
    items: Sequence[dict[str, str | JID]]
    """
    The table rows. Each row maps a column's ``var`` attribute to the cell
    value.
    """
    description: str
    """
    What the table contains; used as the form title.
    """

    # NOTE(review): not read by this class; presumably tells consumers that
    # JID cells refer to group chats — confirm with callers.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Render the table as a slixmpp "result" form: one <reported> header
        entry per column, then one item per row (cells converted to str).

        :return: some XML
        """
        table = SlixForm()
        table["type"] = "result"
        table["title"] = self.description
        for column in self.fields:
            table.add_reported(column.var, label=column.label, type=column.type)
        for row in self.items:
            cells = {name: str(cell) for name, cell in row.items()}
            table.add_item(cells)
        return table
@dataclass
class SearchResult(TableResult):
    """
    Table of contacts produced by the search command (Jabber Search).

    This is what :meth:`BaseSession.search` returns.
    """

    # Only difference from TableResult: the title has a sensible default.
    description: str = "Contact search results"
@dataclass
class Confirmation:
    """
    A yes/no confirmation 'dialog' shown to the user
    """

    prompt: str
    """
    Question displayed to the user who triggered the command
    """
    handler: Any
    """
    Async function to call on confirmation; it should return a ResponseType
    """
    success: str | None = None
    """
    Message to show on success when the handler returns nothing
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    Positional arguments forwarded to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    Keyword arguments forwarded to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form: the prompt as title, plus a single boolean
        "confirm" field that is pre-set to true.

        :return: some xml
        """
        confirm = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        dialog = SlixForm()
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm.get_xml())
        return dialog
@dataclass
class ConfirmationSession(Confirmation, Generic[SessionType]):
    """
    A :class:`Confirmation` whose handler receives the caller's session
    (``None`` if unregistered) and JID.
    """

    handler: Callable[
        [SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]
# A contact or a group chat: the entity a recipient-specific command acts on.
RecipientType = TypeVar(
    "RecipientType",
    bound=LegacyContact[Any] | LegacyMUC[Any, Any, Any, Any],
)
@dataclass
class ConfirmationRecipient(Confirmation, Generic[RecipientType]):
    """
    A :class:`Confirmation` whose handler receives the contact or group
    the command was invoked on.
    """

    handler: Callable[
        [RecipientType],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]
@dataclass
class Form:
    """
    A form used to request input from the user
    """

    # Form title and instructions shown to the user.
    title: str
    instructions: str
    # The fields to fill in; also used to validate the submission.
    fields: Sequence["FormField"]
    # Callback for the submission, plus extra arguments forwarded to it.
    handler: Any
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): not invoked in this module; presumably called when the
    # form is abandoned — confirm with the command dispatchers.
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Validate and convert a form submission.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each declared field's ``var`` to the value
            returned by :meth:`FormField.validate` (e.g. a JID for
            jid-single fields)
        """
        submitted: dict[str, str | list[str]] = slix_form.get_values()
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Render this form as a slixmpp "form" with title, instructions and
        one child per field.

        :return: some XML
        """
        rendered = SlixForm()
        rendered["type"] = "form"
        rendered["title"] = self.title
        rendered["instructions"] = self.instructions
        for form_field in self.fields:
            rendered.append(form_field.get_xml())
        return rendered
@dataclass
class FormSession(Form, Generic[SessionType]):
    """
    A :class:`Form` whose handler receives the submitted values, the caller's
    session (``None`` if unregistered) and its JID.
    """

    # Decorated like its siblings (ConfirmationSession, FormRecipient) so the
    # narrowed ``handler`` annotation is an actual dataclass field and is
    # reflected in the generated __init__ signature.
    handler: Callable[
        [FormValues, SessionType | None, JID],
        Awaitable["CommandResponseSessionType[SessionType]"],
    ]
@dataclass
class FormRecipient(Form, Generic[RecipientType]):
    """
    A :class:`Form` whose handler receives the contact or group the command
    was invoked on, followed by the submitted values.
    """

    handler: Callable[
        [RecipientType, FormValues],
        Awaitable["CommandResponseRecipientType[RecipientType]"],
    ]
class CommandAccess(int, Enum):
    """
    Who may use a given Command (enforced by
    :meth:`Command.raise_if_not_authorized`)
    """

    ADMIN_ONLY = 0  # only JIDs listed in config.ADMINS
    USER = 1  # callers with a session
    USER_LOGGED = 2  # callers with a logged-in session
    USER_NON_LOGGED = 3  # callers with a session that is not logged in
    NON_USER = 4  # callers without a session
    ANY = 5  # no extra restriction
class Option(TypedDict):
    """
    One choice of a ``list-*`` ``FormField``
    """

    # Text shown to the user.
    label: str
    # Value submitted back when this choice is selected.
    value: str
# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of a form that a user will see in their XMPP client,
    for instance when registering to the gateway.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces ``type`` to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """
    Acceptable choices. Mandatory for ``list-single`` and ``list-multi``
    fields; also enforced on ``jid-multi`` fields when set.
    """

    image_url: str | None = None
    """An image associated to this field, eg, a QR code (advertised as ``image/png``)"""

    def __post_init__(self) -> None:
        # ``private`` always overrides the declared type.
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        # Validating against options on a field that declares none is a
        # programming error, not a user error.
        if self.options is None:
            raise RuntimeError
        return [x["value"] for x in self.options]

    @staticmethod
    def __parse_jid(value: str | JID) -> JID:
        # JID() signals malformed input with ValueError; turn it into an XMPP
        # error so callers only have to deal with XMPPError.
        try:
            return JID(value)
        except ValueError:
            raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'")

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Check a submitted value against this field, raising an appropriate
        XMPPError if it is not valid, and convert it.

        :param value: The value to test
        :return: For ``*-multi`` fields, a list (``[]`` if nothing was
            submitted; JIDs for ``jid-multi``). For single-valued fields,
            ``None`` if nothing was submitted, a :class:`JID` for
            ``jid-single``, a bool for ``boolean``, the value itself otherwise.
        :raises XMPPError: ``not-acceptable`` on wrong cardinality, a missing
            required single value, a value outside the declared options, or
            an invalid JID.
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi", "text-multi"):
            if not value:
                value = []
            if not isinstance(value, list):
                raise XMPPError("not-acceptable", "Multiple values was expected")
            if self.type == "text-multi":
                return value
            return self.__validate_list_multi(value)

        assert isinstance(value, (str, bool, JID)) or value is None

        if value is None:
            if self.required:
                # Fall back to ``var`` so the message never reads 'None'.
                raise XMPPError(
                    "not-acceptable", f"Missing field: '{self.label or self.var}'"
                )
            return None

        if self.type == "jid-single":
            return self.__parse_jid(value)

        if self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")
        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        # list-multi values must be declared options. jid-multi values are
        # only checked against options when the field declares some, so a
        # jid-multi field without options accepts any valid JID.
        if value and (self.type == "list-multi" or self.options is not None):
            acceptable = set(self.__acceptable_options())
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value
        return [self.__parse_jid(v) for v in value]

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f
# What a command may return: a table, a confirmation dialog, a form, a plain
# text message, or nothing.
CommandResponseType = TableResult | Confirmation | Form | str | None

# Variant for commands whose dialog handlers receive the caller's session.
CommandResponseSessionType = (
    TableResult
    | ConfirmationSession[SessionType]
    | FormSession[SessionType]
    | str
    | None
)

# Variant for commands whose dialog handlers receive a contact or group.
CommandResponseRecipientType = (
    TableResult
    | ConfirmationRecipient[RecipientType]
    | FormRecipient[RecipientType]
    | str
    | None
)
class _CommandMixin(ABC):
    """
    Descriptive attributes shared by every kind of command. They all start
    out as the ``NotImplemented`` placeholder.
    """

    NAME: str = NotImplemented
    """
    Human-friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Longer explanation of what the command does
    """
    NODE: str = NotImplemented
    """
    Node name under which the command is exposed as an ad-hoc command
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text that, sent to the gateway as a message, triggers the command
    """
class Command(_CommandMixin, Generic[SessionType]):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Defining a subclass is enough to register it in :attr:`subclasses`.
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy only used for the adhoc interface, not the chat command
    interface.
    """

    # Shared by the whole hierarchy: every subclass appends itself here.
    # Typing forbids type variables inside ClassVar, hence Command[Any].
    subclasses: ClassVar[list[type["Command[Any]"]]] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first: typing.Generic relies on
        # __init_subclass__ to compute ``__parameters__`` for subclasses.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self,
        session: SessionType | None,
        ifrom: JID,
        *args: str,
    ) -> CommandResponseSessionType[SessionType]:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: ``feature-not-implemented`` unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> SessionType | None:
        return self.xmpp.get_session_from_jid(jid)  # type:ignore

    def __can_use_command(self, jid: JID) -> bool:
        # Allowed if the bare JID matches the gateway's JID validator, or is
        # a configured admin.
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j) or j in config.ADMINS)

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: SessionType | None = None,
    ) -> SessionType | None:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look the session up from ``jid``,
            ignoring the ``session`` argument
        :param session: session to check when ``fetch_session`` is False

        :return: session of JID if it exists
        :raises XMPPError: if the JID may not use the gateway, or does not
            meet this command's :attr:`ACCESS` requirement
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session
# The recipient a _RecipientCommand is parametrised with: a contact or a MUC.
T = TypeVar(
    "T",
    bound=AnyContact | AnyMUC,
)
class _RecipientCommand(_CommandMixin, Generic[T]):
    """
    Base class for commands attached to a specific recipient (a contact or a
    MUC) rather than to the gateway itself.
    """

    @staticmethod
    @abstractmethod
    async def run(recipient: T, *args: str) -> CommandResponseRecipientType[T]:
        """
        Entrypoint for a recipient-specific command.

        The first argument is a :class:`LegacyContact` or :class:`LegacyMUC`
        instance. ``*args`` are extra args passed when using the chatbot.
        """
        # The return type is parametrised with the class's own type variable
        # ``T`` so it is tied to the ``recipient`` argument.
        raise NotImplementedError
class ContactCommand(_RecipientCommand[AnyContact], Generic[LegacyContactType]):
    """
    A command that will be available on a contact.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyContact

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first (typing.Generic computes
        # ``__parameters__`` for subclasses there).
        super().__init_subclass__(**kwargs)
        # Register the command on the contact class, by ad-hoc node and by
        # chat command keyword.
        cls.recipient_cls.commands[cls.NODE] = cls  # type:ignore[assignment]
        cls.recipient_cls.commands_chat[cls.CHAT_COMMAND] = cls  # type:ignore[assignment]
class MUCCommand(_RecipientCommand[AnyMUC], Generic[LegacyMUCType]):
    """
    A command that will be available on a MUC.

    It implicitly requires the user to be registered and logged.
    It is never instantiated, so all methods must be static methods.
    Its entrypoint is the ``run()`` static method.
    """

    recipient_cls = LegacyMUC

    def __init_subclass__(
        cls,
        **kwargs: Any,  # noqa:ANN401
    ) -> None:
        # Chain to the parent hooks first (typing.Generic computes
        # ``__parameters__`` for subclasses there).
        super().__init_subclass__(**kwargs)
        # Register the command on the MUC class, by ad-hoc node and by chat
        # command keyword.
        cls.recipient_cls.commands[cls.NODE] = cls
        cls.recipient_cls.commands_chat[cls.CHAT_COMMAND] = cls
def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway admin.

    The resource part is ignored: only the bare JID is looked up in
    ``config.ADMINS``.
    """
    bare = JID(jid).bare
    return bare in config.ADMINS