Coverage for slidge / command / adhoc.py: 85%
202 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
1import asyncio
2import inspect
3import logging
4from collections.abc import Awaitable, Callable
5from functools import partial
6from typing import TYPE_CHECKING, Any, Generic, Optional, ParamSpec, TypeVar, cast
8from slixmpp import JID, Iq
9from slixmpp.exceptions import XMPPError
10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
12from slixmpp.plugins.xep_0050.adhoc import CommandType
14from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession
16from ..core import config
17from ..util.util import strip_leading_emoji
18from . import Command, CommandResponseType, Confirmation, Form, TableResult
19from .base import (
20 CommandResponseRecipientType,
21 CommandResponseSessionType,
22 ContactCommand,
23 FormField,
24 MUCCommand,
25)
26from .categories import CommandCategory
28if TYPE_CHECKING:
29 from ..core.gateway import BaseGateway
30 from ..core.session import BaseSession
33GatewayType = TypeVar("GatewayType", bound="BaseGateway[Any]")
34AdhocSessionType = dict[str, Any]
35T = TypeVar("T")
36P = ParamSpec("P")
class AdhocProvider(Generic[GatewayType]):
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Gateway-level commands are registered with slixmpp's XEP-0050 plugin via
    :meth:`register`. Commands attached to contacts or groups are resolved on
    demand through the ``get_command`` API hook installed in ``__init__``.
    """

    FORM_TIMEOUT = 120  # seconds
    xmpp: GatewayType

    def __init__(self, xmpp: GatewayType) -> None:
        self.xmpp = xmpp
        # Uncategorized commands, keyed by their node.
        self._commands = dict[str, Command[AnySession]]()
        # Categorized commands, keyed by the category's node.
        self._categories = dict[str, list[Command[AnySession]]]()
        # Pending form timeouts, keyed by adhoc session id.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        self.xmpp.plugin["xep_0050"].api.register(self.__get_command, "get_command")

    async def __get_command(
        self,
        jid: JID | str | None = None,
        node: str | None = None,
        ifrom: JID | None = None,
        args: None = None,
    ) -> CommandType | None:
        """
        Resolve a command for slixmpp's XEP-0050 ``get_command`` API hook.

        Commands addressed to the gateway's bare JID are looked up in the
        XEP-0050 plugin's own ``commands`` registry. Commands addressed to any
        other JID are looked up on the matching contact or group of the
        requester's session. A participant is resolved to its contact.

        :param jid: the JID the command is addressed to (defaults to the
            gateway's bare JID)
        :param node: the command node
        :param ifrom: the JID of the requesting entity
        :param args: unused
        :return: a command entry, or ``None`` if there is no such command
        :raises XMPPError: if ``ifrom`` is missing or has no session
        """
        if node is None:
            return None
        if jid is None:
            jid = self.xmpp.boundjid.bare
        if not isinstance(jid, JID):
            jid = JID(jid)
        if jid == self.xmpp.boundjid.bare:
            return self.xmpp.plugin["xep_0050"].commands.get((jid.full, node))
        if ifrom is None:
            raise XMPPError("undefined-condition")
        session = self.xmpp.get_session_from_jid(ifrom)
        if session is None:
            raise XMPPError("subscription-required")
        recipient = await session.get_contact_or_group_or_participant(jid)
        if recipient is None:
            return None
        if recipient.is_participant:
            if recipient.contact is None:
                return None
            recipient = recipient.contact
        command = recipient.commands.get(node)
        if command is None:
            return None
        name = strip_leading_emoji_if_needed(command.NAME)
        handler = partial(self.__wrap_initial_handler, command, recipient=recipient)
        # NOTE(review): the tuple is meant to mirror the shape of the entries
        # in slixmpp's command registry (CommandType) — confirm on upgrades.
        return name, handler, None, None  # type:ignore

    async def __wrap_initial_handler(
        self,
        command: Command[AnySession]
        | type[ContactCommand[AnyContact]]
        | type[MUCCommand[AnyMUC]],
        iq: Iq,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Run the first step of a command.

        Without ``recipient`` (gateway-level command), the requester's
        authorization is checked and ``command.run(session, ifrom)`` is
        called. With a ``recipient`` (contact/group command),
        ``command.run(recipient)`` is called.

        :return: the updated adhoc session
        """
        ifrom = iq.get_from()
        if recipient is None:
            cmd = cast(Command[AnySession], command)
            session = cmd.raise_if_not_authorized(ifrom)
            result1: CommandResponseSessionType[Any] = await self.__wrap_handler(
                cmd.run, session, ifrom
            )
            return await self.__handle_result(session, result1, adhoc_session)
        else:
            cmd2 = cast(
                type[ContactCommand[AnyContact]] | type[MUCCommand[AnyMUC]], command
            )
            result2: CommandResponseRecipientType[Any] = await self.__wrap_handler(
                cmd2.run,  # type:ignore[arg-type]
                recipient,
            )
            return await self.__handle_result(
                recipient.session, result2, adhoc_session, recipient
            )

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Present a form listing the commands of ``category`` that the
        requester is authorized to run.

        :raises XMPPError: ``not-authorized`` if no command of the category
            may be run by the requester
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # Some commands (e.g. for non-registered users) do not need one.
            session = None
        commands: dict[str, Command[AnySession]] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command[AnySession]],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseSessionType[Any]:
        """
        Run the command picked in the category selection form.

        :param commands: the commands offered in the form, keyed by node
        :param form_values: submitted values; ``"command"`` holds the node
        """
        command = commands[form_values["command"]]
        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            command.run, session, jid
        )
        return result

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Translate a command handler's return value into adhoc session state.

        - ``str``/``None``: end the session with an info note.
        - :class:`Form`: continue with the form; if it defines a
          ``timeout_handler``, schedule it after :attr:`FORM_TIMEOUT` seconds.
        - :class:`Confirmation`: continue with a confirmation form.
        - :class:`TableResult`: end the session with the table as payload.

        :raises XMPPError: ``internal-server-error`` for any other value
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(
                self.__wrap_form_handler, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(
                self.__wrap_confirmation, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Expire an adhoc session whose form was not submitted in time.

        Forgets the fired timer, drops the session from the XEP-0050 plugin
        (if still present) and calls the form's ``timeout_handler``.
        """
        # The timer has fired: drop its handle so it does not linger forever.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T:
        """
        Call a sync or async handler and normalize its errors.

        :raises XMPPError: re-raised as is if the handler raised one; any
            other exception becomes ``internal-server-error``
        """
        try:
            if inspect.iscoroutinefunction(f):
                return await f(*a, **k)  # type:ignore[no-any-return]
            elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
                # functools.partial wrapping a coroutine function
                return await f(*a, **k)  # type:ignore[misc,no-any-return]
            else:
                return f(*a, **k)  # type:ignore[return-value]
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle a submitted form: cancel its pending timeout, if any, and call
        the form's handler with the parsed values.

        Gateway-level: ``handler(values, session, from, *args, **kwargs)``.
        Contact/group: ``handler(recipient, values, *args, **kwargs)``.
        """
        timer = self.__timeouts.pop(adhoc_session["id"], None)
        if timer is not None:
            log.debug("Canceled form timeout for adhoc session %s", adhoc_session["id"])
            timer.cancel()
        form_values = result.get_values(form)
        if recipient is None:
            new_result = await self.__wrap_handler(
                result.handler,
                form_values,
                session,
                adhoc_session["from"],
                *result.handler_args,
                **result.handler_kwargs,
            )
        else:
            new_result = await self.__wrap_handler(
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
        return await self.__handle_result(session, new_result, adhoc_session, recipient)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a confirmation form.

        If confirmed, the confirmation's handler runs, and its
        ``success`` message (when set) replaces the handler's result.
        Otherwise, the operation is reported as canceled.
        """
        if form.get_values().get("confirm"):
            if recipient is None:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    session,
                    adhoc_session["from"],
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            else:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    recipient,
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            # Honor the static success message for gateway-level AND
            # contact/group commands alike.
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session, recipient)

    def register(self, command: Command[AnySession], jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        Uncategorized commands get their own XEP-0050 node. Categorized
        commands are grouped under one node per category, which presents a
        form to pick one of the category's commands.

        :param command: the command to register
        :param jid: the JID to register it on (defaults to the gateway's JID)
        :raises RuntimeError: if an uncategorized command already uses the
            same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command[AnySession]]()
                self.xmpp.plugin["xep_0050"].add_command(
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command[AnySession],
        ifrom: JID,
        session: Optional["BaseSession[Any, Any]"],
    ) -> bool:
        """Return whether ``ifrom`` may run ``command``, without fetching a session."""
        try:
            command.raise_if_not_authorized(
                ifrom, fetch_session=False, session=session
            )
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester is neither an
            allowed user nor an admin
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands is allowed.
                authorized = any(
                    self.__is_authorized(command, ifrom, session)
                    for command in self._categories[item_node]
                )
            elif (command := self._commands.get(item_node)) is not None:
                authorized = self.__is_authorized(command, ifrom, session)
            else:
                # Not registered through this provider: we cannot check
                # authorization, so fail closed instead of crashing the query.
                log.warning("Unknown adhoc command node in disco items: %s", item_node)
                authorized = False

            if authorized:
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Conditionally pass an adhoc command/category name through
    ``strip_leading_emoji``.

    :param text: the display name
    :return: ``strip_leading_emoji(text)`` when the
        ``STRIP_LEADING_EMOJI_ADHOC`` config option is truthy, else ``text``
        unchanged
    """
    should_strip = config.STRIP_LEADING_EMOJI_ADHOC
    return strip_leading_emoji(text) if should_strip else text
# Module-level logger used by every function and method of this module.
log: logging.Logger = logging.getLogger(name=__name__)