Coverage for slidge / command / base.py: 94%
204 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-03-13 22:59 +0000
1from abc import ABC
2from collections.abc import Awaitable, Callable, Iterable, Sequence
3from dataclasses import dataclass, field
4from enum import Enum
5from typing import (
6 TYPE_CHECKING,
7 Any,
8 Generic,
9 TypedDict,
10 Union,
11)
13from slixmpp import JID
14from slixmpp.exceptions import XMPPError
15from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
16from slixmpp.plugins.xep_0004 import FormField as SlixFormField
17from slixmpp.types import JidStr
19from ..core import config
20from ..util.types import AnyBaseSession, FieldType, SessionType
# Base URI for "core" command nodes. NOTE(review): presumably prepended to the
# NODE of slidge's built-in commands — confirm at the usage sites.
NODE_PREFIX = "https://slidge.im/command/core/"
24if TYPE_CHECKING:
25 from ..core.gateway import BaseGateway
26 from .categories import CommandCategory
# A command handler: receives the session and the JID of the requesting
# entity, and returns a command response either directly (plain function)
# or via an awaitable (coroutine function).
HandlerType = (
    Callable[[AnyBaseSession, JID], "CommandResponseType"]
    | Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]]
)

# Values of a submitted form, keyed by each field's ``var``; this is the first
# argument given to ``Form.handler``.
FormValues = dict[str, str | JID | bool]
@dataclass
class TableResult:
    """
    Tabular data returned as the outcome of a command.
    """

    fields: Sequence["FormField"]
    """
    Column definitions of the table (one ``FormField`` per column).
    """
    items: Sequence[dict[str, str | JID]]
    """
    Rows of the table. Each row maps a column's ``var`` attribute to the cell
    value.
    """
    description: str
    """
    Human-readable description of what the table contains.
    """

    jids_are_mucs: bool = False
    """
    Flag, presumably meaning that the JID values in ``items`` are group chat
    addresses. It is not read by :meth:`get_xml`.
    """

    def get_xml(self) -> SlixForm:
        """
        Render the table as a slixmpp "result" form with a <reported> header.

        :return: some XML
        """
        result = SlixForm()  # type: ignore[no-untyped-call]
        result["type"] = "result"
        result["title"] = self.description
        # header first: one <reported> entry per column
        for column in self.fields:
            result.add_reported(column.var, label=column.label, type=column.type)  # type: ignore[no-untyped-call]
        # then one item per row, every cell serialized as text
        for row in self.items:
            cells = {key: str(cell) for key, cell in row.items()}
            result.add_item(cells)  # type: ignore[no-untyped-call]
        return result
@dataclass
class SearchResult(TableResult):
    """
    Outcome of the contact search command (contact lookup via Jabber Search).

    This is what :meth:`BaseSession.search` returns.
    """

    description: str = field(default="Contact search results")
@dataclass
class Confirmation(Generic[SessionType]):
    """
    A yes/no confirmation 'dialog' presented to the user.
    """

    prompt: str
    """
    Question shown to the user who triggered the command
    """
    handler: Callable[
        [SessionType | None, JID],
        Awaitable["CommandResponseType"],
    ]
    """
    Coroutine function producing the command response once confirmed
    """
    success: str | None = None
    """
    Fallback text reported on success when the handler returns nothing
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    Extra positional arguments for the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    Extra keyword arguments for the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form: the prompt as title, plus a single boolean
        "confirm" field that defaults to true.

        :return: some xml
        """
        confirm_field = FormField(
            var="confirm", type="boolean", value="true", label="Confirm"
        )
        dialog = SlixForm()  # type: ignore[no-untyped-call]
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        dialog.append(confirm_field.get_xml())
        return dialog
@dataclass
class Form(Generic[SessionType]):
    """
    A form asking the user for input.
    """

    # used as the form title in get_xml()
    title: str
    # used as the form instructions in get_xml()
    instructions: str
    # fields rendered by get_xml() and validated by get_values()
    fields: Sequence["FormField"]
    # coroutine function receiving the parsed values, the session and a JID
    handler: Callable[
        [FormValues, SessionType | None, JID],
        Awaitable["CommandResponseType"],
    ]
    # extra handler arguments; NOTE(review): consumed outside this module
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): not used in this module; presumably called when the form
    # expires — confirm with the caller
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Parse and validate a form submission.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each field's ``var`` to the value returned by
            :meth:`FormField.validate` for it (str, JID, bool, list or None)
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" containing every field of this form.

        :return: some XML
        """
        out = SlixForm()  # type: ignore[no-untyped-call]
        out["type"] = "form"
        out["title"] = self.title
        out["instructions"] = self.instructions
        for xml_field in (fi.get_xml() for fi in self.fields):
            out.append(xml_field)
        return out
class CommandAccess(int, Enum):
    """
    Defines who can access a given Command

    The checks for each member are applied in
    ``Command.raise_if_not_authorized``.
    """

    # only JIDs listed in config.ADMINS
    ADMIN_ONLY = 0
    # any entity that has a session (registered to the gateway)
    USER = 1
    # registered users whose session is logged to the legacy service
    USER_LOGGED = 2
    # registered users whose session is NOT logged to the legacy service
    USER_NON_LOGGED = 3
    # only entities without a session (not registered)
    NON_USER = 4
    # no additional restriction beyond the gateway-wide JID check
    ANY = 5
class Option(TypedDict):
    """
    Options to be used for ``FormField``s of type ``list-*``
    """

    # text shown to the user for this choice
    label: str
    # value submitted when this choice is picked; submissions are checked
    # against these values during validation
    value: str
# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """Allowed choices, mandatory for ``list-single`` and ``list-multi`` fields"""

    image_url: str | None = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        # A list-* field without options is a programming error, not bad input.
        if self.options is None:
            raise RuntimeError
        return [x["value"] for x in self.options]

    @staticmethod
    def __parse_jid(value: str) -> JID:
        """
        Convert ``value`` to a JID, turning a parse failure into a user-facing error.

        :raises XMPPError: (not-acceptable) if ``value`` is not a valid JID
        """
        try:
            return JID(value)
        except ValueError as e:
            raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'") from e

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Raise an appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test
        :return: The same value, OR a JID if ``self.type=jid-single``, a list of
            JIDs if ``self.type=jid-multi``, a bool if ``self.type=boolean``
        :raises XMPPError: (not-acceptable) on a missing required value, a wrong
            cardinality, an invalid JID or a value outside of ``options``
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi", "text-multi"):
            # an absent multi-valued field is treated as "no values"
            if not value:
                value = []
            if isinstance(value, list):
                if self.type == "text-multi":
                    return value
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            raise XMPPError("not-acceptable", f"Missing field: '{self.label}'")

        if value is None:
            return None

        if self.type == "jid-single":
            return self.__parse_jid(value)  # type: ignore[arg-type]

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        # list-multi must always be checked against its options (and still raises
        # RuntimeError if it has none). jid-multi fields normally carry no
        # <option/> (XEP-0004), so only enforce the whitelist when options exist;
        # previously any non-empty jid-multi submission without options crashed.
        if value and (self.type == "list-multi" or self.options is not None):
            acceptable = set(self.__acceptable_options())
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value
        # jid-multi: same user-facing error as jid-single on malformed JIDs
        return [self.__parse_jid(v) for v in value]

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        Options are only added when ``self.options`` is non-empty, and
        ``image_url`` (if set) is attached as a media URI of type image/png.

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f
# Anything a command step may return: a table, a confirmation dialog, a form,
# a plain text reply, or None (see Command.run).
CommandResponseType = TableResult | Confirmation | Form | str | None
class Command(ABC, Generic[SessionType]):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Merely defining a subclass registers it in :attr:`Command.subclasses`.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy only used for the adhoc interface, not the chat command
    interface.
    """

    # Single registry for the whole hierarchy: __init_subclass__ appends every
    # new subclass to this list (unless a subclass shadows the attribute).
    subclasses = list[type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Keep the cooperative __init_subclass__ chain intact: typing.Generic
        # (a base of this class) defines its own hook, and extra class keyword
        # arguments must reach the parents instead of being silently dropped.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self,
        session: SessionType | None,
        ifrom: JID,
        *args: str,
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: feature-not-implemented unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> SessionType | None:
        """Return the gateway session associated with ``jid``, if any."""
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        """True if the bare JID passes the gateway's JID validator or is an admin."""
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j) or j in config.ADMINS)

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: SessionType | None = None,
    ) -> SessionType | None:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look up the session of ``jid`` and ignore
            the ``session`` argument
        :param session: session to check against :attr:`ACCESS` when
            ``fetch_session`` is False

        :return: the session that was checked (the session of JID, or None)
        :raises XMPPError: bad-request, not-authorized or forbidden depending on
            which access rule is violated
        """
        # gateway-wide filter, applies whatever the command's ACCESS is
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session
def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    :param jid: a JID, as an object or a string (see the ``JidStr`` annotation)
    :return: True if the bare form of ``jid`` is listed in ``config.ADMINS``
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS