Coverage for slidge / command / adhoc.py: 86%
196 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1import asyncio
2import inspect
3import logging
4from collections.abc import Awaitable, Callable
5from functools import partial
6from typing import TYPE_CHECKING, Any, Optional, ParamSpec, TypeVar, cast
8from slixmpp import JID, Iq
9from slixmpp.exceptions import XMPPError
10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
12from slixmpp.plugins.xep_0050.adhoc import CommandType
14from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession
16from ..core import config
17from ..util.util import strip_leading_emoji
18from . import Command, CommandResponseType, Confirmation, Form, TableResult
19from .base import (
20 CommandResponseRecipientType,
21 CommandResponseSessionType,
22 ContactCommand,
23 FormField,
24 MUCCommand,
25)
26from .categories import CommandCategory
28if TYPE_CHECKING:
29 from ..core.gateway import BaseGateway
30 from ..core.session import BaseSession
# State dict handed to adhoc handlers by slixmpp's xep_0050 plugin. This module
# reads/writes keys such as "id", "from", "next", "has_next", "payload" and
# "notes" on it; values are untyped, hence ``Any``.
AdhocSessionType = dict[str, Any]
# Generic return type and parameter spec used to type AdhocProvider's
# handler wrapper, so wrapped callables keep their signatures for type checkers.
T = TypeVar("T")
P = ParamSpec("P")
class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    It registers gateway-level commands with the ``xep_0050`` plugin, groups
    commands that declare a category behind a single category node, filters
    disco#items answers so users only see commands they may run, and resolves
    contact/group commands on demand through the ``get_command`` API hook.
    """

    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Commands without a category, keyed by their node.
        self._commands = dict[str, Command[AnySession]]()
        # Commands grouped by category node.
        self._categories = dict[str, list[Command[AnySession]]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        self.xmpp.plugin["xep_0050"].api.register(self.__get_command, "get_command")
        # Pending form-timeout timers, keyed by adhoc session id.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __get_command(
        self,
        jid: JID | str | None = None,
        node: str | None = None,
        ifrom: JID | None = None,
        args: None = None,
    ) -> CommandType | None:
        """
        ``get_command`` API hook registered with the ``xep_0050`` plugin.

        Commands addressed to the gateway's bare JID are looked up in the
        plugin's own registry. For any other JID, the requester's session is
        used to resolve a contact or group, whose ``commands`` mapping is then
        queried for ``node``.

        :param jid: the entity the command is addressed to (defaults to the
            gateway's bare JID)
        :param node: the requested command node
        :param ifrom: the JID of the requesting entity
        :param args: unused; part of the slixmpp API hook signature
        :return: a tuple whose first two items are the command name and its
            handler, or ``None`` if no matching command exists
        :raises XMPPError: ``undefined-condition`` if ``ifrom`` is missing for a
            non-gateway JID, ``subscription-required`` if ``ifrom`` has no
            session
        """
        if node is None:
            return None
        if jid is None:
            jid = self.xmpp.boundjid.bare
        if not isinstance(jid, JID):
            jid = JID(jid)
        if jid == self.xmpp.boundjid.bare:
            return self.xmpp.plugin["xep_0050"].commands.get((jid.full, node))
        if ifrom is None:
            raise XMPPError("undefined-condition")
        session = self.xmpp.get_session_from_jid(ifrom)
        if session is None:
            raise XMPPError("subscription-required")
        recipient = await session.get_contact_or_group_or_participant(jid)
        # Participants do not expose commands.
        if recipient is None or recipient.is_participant:
            return None
        command = recipient.commands.get(node)
        if command is None:
            return None
        name = strip_leading_emoji_if_needed(command.NAME)
        handler = partial(self.__wrap_initial_handler, command, recipient=recipient)
        return name, handler, None, None  # type:ignore

    async def __wrap_initial_handler(
        self,
        command: Command[AnySession]
        | type[ContactCommand[AnyContact]]
        | type[MUCCommand[AnyMUC]],
        iq: Iq,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Run the first step of a command and translate its response.

        Without ``recipient``, this is a gateway command: the sender must be
        authorized, and ``run`` receives the session and the sender's JID.
        With ``recipient``, this is a contact/group command whose ``run``
        receives the recipient only.
        """
        ifrom = iq.get_from()
        if recipient is None:
            cmd = cast(Command[AnySession], command)
            session = cmd.raise_if_not_authorized(ifrom)
            result1: CommandResponseSessionType[Any] = await self.__wrap_handler(
                cmd.run, session, ifrom
            )
            return await self.__handle_result(session, result1, adhoc_session)
        else:
            cmd2 = cast(
                type[ContactCommand[AnyContact]] | type[MUCCommand[AnyMUC]], command
            )
            result2: CommandResponseRecipientType[Any] = await self.__wrap_handler(
                cmd2.run,  # type:ignore[arg-type]
                recipient,
            )
            return await self.__handle_result(
                recipient.session, result2, adhoc_session, recipient
            )

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a category node: show a form listing the commands of
        ``category`` that the sender is authorized to run.

        :raises XMPPError: ``not-authorized`` if none of them can be run
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            session = None
        commands: dict[str, Command[AnySession]] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if len(commands) == 0:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command[AnySession]],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseSessionType[Any]:
        """Run the command picked in the category form."""
        command = commands[form_values["command"]]
        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            command.run, session, jid
        )
        return result

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Translate a command response into the adhoc session dict.

        - ``str``/``None``: final step with an info note (``"Success!"`` when
          empty).
        - :class:`Form`: next step answers the form. If the form defines a
          timeout handler, a timer is armed for ``FORM_TIMEOUT`` seconds.
        - :class:`Confirmation`: next step handles the confirmation form.
        - :class:`TableResult`: final step with the table as payload.

        :raises XMPPError: ``internal-server-error`` for any other response type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(
                self.__wrap_form_handler, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(
                self.__wrap_confirmation, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Event-loop callback fired when a form was not answered in time.

        Forgets the timer, drops the adhoc session from ``xep_0050``, then
        calls the form's timeout handler.
        """
        # The timer has fired. Drop its handle, otherwise every unanswered
        # form would leave a stale entry in the mapping forever.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T:
        """
        Call a sync or async handler and normalize its errors.

        :class:`XMPPError` propagates unchanged. Any other exception is logged
        and re-raised as an ``internal-server-error`` carrying its message.
        """
        try:
            if inspect.iscoroutinefunction(f):
                return await f(*a, **k)  # type:ignore[no-any-return]
            elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
                # functools.partial wrapping a coroutine function
                return await f(*a, **k)  # type:ignore[misc,no-any-return]
            else:
                return f(*a, **k)  # type:ignore[return-value]
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the user's answer to a :class:`Form`.

        Cancels a pending timeout for this session and passes the parsed form
        values to the form's handler. Session commands call it with
        ``(values, session, from_jid, *args)``; contact/group commands call it
        with ``(recipient, values, *args)``.
        """
        timer = self.__timeouts.pop(adhoc_session["id"], None)
        if timer is not None:
            log.debug("Canceling form timeout for adhoc session %s", adhoc_session["id"])
            timer.cancel()
        form_values = result.get_values(form)
        if recipient is None:
            new_result = await self.__wrap_handler(
                result.handler,
                form_values,
                session,
                adhoc_session["from"],
                *result.handler_args,
                **result.handler_kwargs,
            )
        else:
            new_result = await self.__wrap_handler(
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
        return await self.__handle_result(session, new_result, adhoc_session, recipient)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the user's answer to a :class:`Confirmation`.

        If the user confirmed, run the handler: session commands receive
        ``(session, from_jid, *args)`` and contact/group commands receive
        ``(recipient, *args)``. When the confirmation defines a ``success``
        message, it replaces the handler's result on both paths.
        """
        if not form.get_values().get("confirm"):
            result = "You canceled the operation"
        else:
            if recipient is None:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    session,
                    adhoc_session["from"],
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            else:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    recipient,
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            # Apply the success message consistently for session and
            # contact/group commands alike.
            if confirmation.success:
                result = confirmation.success

        return await self.__handle_result(session, result, adhoc_session, recipient)

    def register(self, command: Command[AnySession], jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually; ``BaseGateway`` takes care of
        that.

        Commands without a category get their own node. Commands with a
        category are grouped under the category's node, which is registered
        the first time the category is seen.

        :param command: the command to register
        :param jid: the JID the command is exposed on (defaults to the
            gateway's bound JID)
        :raises RuntimeError: if an uncategorized command with the same node
            was already registered
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command[AnySession]]()
                self.xmpp.plugin["xep_0050"].add_command(
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __can_run(
        command: Command[AnySession],
        ifrom: JID,
        session: Optional["BaseSession[Any, Any]"],
    ) -> bool:
        """Return whether ``ifrom`` passes ``command``'s authorization check."""
        try:
            command.raise_if_not_authorized(ifrom, fetch_session=False, session=session)
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the sender neither matches the
            user JID validator nor is listed in ``ADMINS``
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands can run.
                candidates = self._categories[item_node]
            else:
                candidates = [self._commands[item_node]]
            if any(self.__can_run(c, ifrom, session) for c in candidates):
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return *text* with its leading emoji removed when the
    ``STRIP_LEADING_EMOJI_ADHOC`` setting is enabled; otherwise return *text*
    unchanged.
    """
    if not config.STRIP_LEADING_EMOJI_ADHOC:
        return text
    return strip_leading_emoji(text)
# Module-level logger. AdhocProvider's methods only look it up when they run,
# so defining it after the class is fine.
log: logging.Logger = logging.getLogger(__name__)