Coverage for slidge/command/base.py: 93%
205 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1from abc import ABC
2from dataclasses import dataclass, field
3from enum import Enum
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Awaitable,
8 Callable,
9 Iterable,
10 Optional,
11 Sequence,
12 Type,
13 TypedDict,
14 Union,
15)
17from slixmpp import JID
18from slixmpp.exceptions import XMPPError
19from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
20from slixmpp.plugins.xep_0004 import FormField as SlixFormField
21from slixmpp.types import JidStr
23from ..core import config
24from ..db.models import GatewayUser
25from ..util.types import AnyBaseSession, FieldType
# URI prefix, presumably shared by the ad-hoc command nodes of slidge's core
# commands -- TODO confirm against the command definitions.
NODE_PREFIX: str = "https://slidge.im/command/core/"
29if TYPE_CHECKING:
30 from ..core.gateway import BaseGateway
31 from ..core.session import BaseSession
32 from .categories import CommandCategory
# Callback of a command: receives a session and the JID of the requester, and
# returns a response either directly or through an awaitable.
HandlerType = Union[
    Callable[[AnyBaseSession, JID], "CommandResponseType"],
    Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]],
]

# Validated values of a submitted form, keyed by field ``var``.
FormValues = dict[str, Union[str, JID, bool]]

# Async callback receiving the form values, the session and the requester's
# JID (presumably invoked once a Form has been submitted).
FormHandlerType = Callable[
    [FormValues, AnyBaseSession, JID],
    Awaitable["CommandResponseType"],
]

# Async callback backing a Confirmation; unlike the other handlers, the
# session argument may be None.
ConfirmationHandlerType = Callable[
    [Union[AnyBaseSession, None], JID],
    Awaitable["CommandResponseType"],
]
@dataclass
class TableResult:
    """
    Tabular data returned as the result of a command.

    It is rendered as an XEP-0004 ``result`` form: a ``<reported>`` header
    describing the columns, followed by one item per row.
    """

    fields: Sequence["FormField"]
    """
    The column definitions ('column names') of the table.
    """
    items: Sequence[dict[str, Union[str, JID]]]
    """
    The table rows. Each row is a dict mapping a field ``var`` to the cell
    value.
    """
    description: str
    """
    What the table contains; used as the title of the form.
    """

    # NOTE(review): not read by get_xml; presumably tells consumers that the
    # JIDs in the rows are group chats -- confirm with its users.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" (with a <reported> header) representing this
        table.

        :return: some XML
        """
        form = SlixForm()  # type: ignore[no-untyped-call]
        form["type"] = "result"
        form["title"] = self.description
        # one <reported> entry per column...
        for column in self.fields:
            form.add_reported(  # type: ignore[no-untyped-call]
                column.var, label=column.label, type=column.type
            )
        # ...then one item per row, with every cell value stringified
        for row in self.items:
            cells = {var: str(cell) for var, cell in row.items()}
            form.add_item(cells)  # type: ignore[no-untyped-call]
        return form
@dataclass
class SearchResult(TableResult):
    """
    Contacts found by the search command (search for contacts via Jabber
    Search).

    This is what :meth:`BaseSession.search` returns. The only difference from
    :class:`TableResult` is a default ``description``.
    """

    description: str = "Contact search results"
@dataclass
class Confirmation:
    """
    A yes/no confirmation 'dialog' presented to the user.
    """

    prompt: str
    """
    The question shown to the user who triggered the command
    """
    handler: ConfirmationHandlerType
    """
    An async function that should return a ResponseType
    """
    success: Optional[str] = None
    """
    Text used in case of success when the handler returns nothing
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    positional arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form for this confirmation.

        :return: a ``form`` titled with the prompt, holding a single boolean
            ``confirm`` field whose value is pre-set to true
        """
        confirm_field = FormField(
            var="confirm", type="boolean", value="true", label="Confirm"
        )
        form = SlixForm()  # type: ignore[no-untyped-call]
        form["type"] = "form"
        form["title"] = self.prompt
        form.append(confirm_field.get_xml())
        return form
@dataclass
class Form:
    """
    A form, used to request input from the user.
    """

    title: str
    instructions: str
    fields: Sequence["FormField"]
    handler: FormHandlerType
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, Union[list[str], list[JID], str, JID, bool, None]]:
        """
        Parse and validate a form submission.

        Each field declared in ``fields`` is looked up in the submission and
        passed through :meth:`FormField.validate`, which raises an XMPPError
        on invalid input. Submitted values without a matching field are
        ignored.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each field ``var`` to its validated value:
            a string, a JID (``jid-single``), a bool (``boolean``), a list
            (``*-multi`` types), or None when a single-value field is absent
            from the submission
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" asking the user to fill in ``fields``.

        :return: some XML
        """
        form = SlixForm()  # type: ignore[no-untyped-call]
        form["type"] = "form"
        form["title"] = self.title
        form["instructions"] = self.instructions
        for form_field in self.fields:
            form.append(form_field.get_xml())
        return form
class CommandAccess(int, Enum):
    """
    Who is allowed to use a given :class:`Command`.

    Enforced by :meth:`Command.raise_if_not_authorized`.
    """

    # only bare JIDs listed in the ADMINS config option
    ADMIN_ONLY = 0
    # entities that have a session, i.e. registered users
    USER = 1
    # registered users whose session is logged in to the legacy service
    USER_LOGGED = 2
    # registered users whose session is not logged in to the legacy service
    USER_NON_LOGGED = 3
    # entities without a session, i.e. not registered to the gateway
    NON_USER = 4
    # no restriction beyond the gateway-wide JID check
    ANY = 5
class Option(TypedDict):
    """
    One choice offered by a ``list-*`` :class:`FormField` (``FormField``s of
    type ``list-*``).
    """

    # text shown to the user for this choice
    label: str
    # what the form submission contains when this choice is picked
    value: str
# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: Optional[str] = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: Optional[list[Option]] = None
    """Allowed choices; mandatory for ``list-*`` fields, optional for ``jid-multi``"""

    image_url: Optional[str] = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # a private field is always rendered as text-private, whatever ``type`` says
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        # Values a list-* field accepts. Calling this without options is a
        # programming error in the field definition, not a user error.
        if self.options is None:
            raise RuntimeError(
                f"Field '{self.var}' of type '{self.type}' has no options"
            )
        return [x["value"] for x in self.options]

    def validate(
        self, value: Optional[Union[str, list[str]]]
    ) -> Union[list[str], list[JID], str, JID, bool, None]:
        """
        Raise an appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test, as found in the submitted form
        :return: The validated value: a JID if ``self.type=jid-single``, a list
            of JIDs if ``self.type=jid-multi``, a list of non-empty strings if
            ``self.type=list-multi``, a bool if ``self.type=boolean``, and the
            value itself otherwise (None if absent and not required)
        :raises XMPPError: ``not-acceptable`` if a required value is missing,
            if the value is malformed, or if it is not one of the options
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi"):
            if not value:
                value = []
            if isinstance(value, list):
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        if isinstance(value, list):
            # Only a *-multi type other than list-multi/jid-multi (i.e.
            # text-multi) can reach this point with a list. Return it as-is
            # instead of tripping the assertion below.
            return value

        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            raise XMPPError("not-acceptable", f"Missing field: '{self.label}'")

        if value is None:
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'")

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> Union[list[str], list[JID]]:
        # COMPAT: all the "if v" and "if not v" are workarounds for https://codeberg.org/slidge/slidge/issues/43
        # They should be reverted once the bug is fixed upstream, cf https://soprani.ca/todo/390
        non_empty = [v for v in value if v]

        # list-multi always restricts values to its options; jid-multi only
        # does when options were given (they are optional for it in XEP-0004).
        if non_empty and (self.type == "list-multi" or self.options is not None):
            acceptable = self.__acceptable_options()  # computed once, not per value
            for v in non_empty:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")

        if self.type == "list-multi":
            return non_empty

        # jid-multi: every value must parse as a JID, reported like jid-single
        jids: list[JID] = []
        for v in non_empty:
            try:
                jids.append(JID(v))
            except ValueError:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'")
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f
# Anything a command may return: a table, a confirmation dialog, a form to
# fill in, a plain text message, or nothing at all.
CommandResponseType = Optional[Union[TableResult, Confirmation, Form, str]]
class Command(ABC):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Optional[Union[str, "CommandCategory"]] = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    # Registry of every subclass, direct or indirect. Unless a subclass rebinds
    # this attribute, the same list object is shared by the whole hierarchy.
    subclasses = list[Type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Chain to the next class in the MRO so that other bases (e.g. mixins)
        # defining __init_subclass__ still get their hook called.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self, session: Optional["BaseSession[Any, Any]"], ifrom: JID, *args: str
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        # delegate the session lookup to the gateway
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        # Gateway-wide restriction: the bare JID must match
        # ``xmpp.jid_validator``, unless it is listed in config.ADMINS.
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j)) or j in config.ADMINS

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: Optional["BaseSession[Any, Any]"] = None,
    ) -> Optional["BaseSession[Any, Any]"]:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look up the session of ``jid`` through
            the gateway; the ``session`` argument is then ignored
        :param session: session of ``jid``, used when ``fetch_session`` is False

        :return: session of JID if it exists
        :raises XMPPError: ``bad-request`` if the JID may not use this gateway
            or is registered while the command is for non-users,
            ``not-authorized`` for a non-admin calling an admin-only command,
            ``forbidden`` when the registration/login state does not match
            :attr:`ACCESS`
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session
def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    Only the bare part of the JID is compared with the ``ADMINS`` config
    option, so the resource (if any) is ignored.

    :param jid: the JID to check (``JidStr``)
    :return: True if the bare JID is listed in ``config.ADMINS``
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS