Coverage for slidge / command / adhoc.py: 88%
161 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
1import asyncio
2import functools
3import inspect
4import logging
5from collections.abc import Callable
6from functools import partial
7from typing import TYPE_CHECKING, Any, Optional
9from slixmpp import JID, Iq
10from slixmpp.exceptions import XMPPError
11from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
12from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
14from ..core import config
15from ..util.util import strip_leading_emoji
16from . import Command, CommandResponseType, Confirmation, Form, TableResult
17from .base import FormField
18from .categories import CommandCategory
20if TYPE_CHECKING:
21 from ..core.gateway import BaseGateway
22 from ..core.session import BaseSession
# Session dict passed to and returned from the XEP-0050 command handlers
# registered by AdhocProvider (also stored in the xep_0050 plugin's
# ``sessions``). This module reads/writes its "id", "from", "next",
# "has_next", "payload" and "notes" keys.
AdhocSessionType = dict[str, Any]
class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a category are registered as their own XEP-0050 node.
    Commands with a category are grouped behind one node per category; running
    that node shows a form listing the commands of the category that the
    requester is allowed to run. Disco#items on the commands node is filtered
    per requester by :meth:`get_items`.
    """

    # Delay after which a pending form that has a ``timeout_handler`` is
    # dropped from the XEP-0050 sessions and that handler is called.
    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Uncategorized commands, keyed by their node.
        self._commands = dict[str, Command]()
        # Categorized commands, keyed by their category's node.
        self._categories = dict[str, list[Command]]()
        # Route disco#items on the commands node through get_items() so the
        # listing can be filtered for each requester.
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # Pending form timeout timers, keyed by adhoc session id.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        XEP-0050 entry point of an uncategorized command.

        :param command: the command bound to this node at registration time
        :param iq: the command execution request
        :param adhoc_session: the XEP-0050 session dict
        :return: the updated session dict
        :raises XMPPError: if the requester may not run the command, or if the
            command fails
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        XEP-0050 entry point of a category node.

        Replies with a single-choice form listing the commands of ``category``
        that the requester is authorized to run. The chosen command is then run
        by :meth:`__handle_category_choice`.

        :raises XMPPError: ``not-authorized`` if the requester cannot run any
            command of this category
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # No session for this requester: continue without one, per-command
            # authorization is checked below.
            session = None

        ifrom = iq.get_from()
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(ifrom)
            except XMPPError:
                continue
            commands[command.NODE] = command

        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )

        form = Form(
            category.name,
            "",
            [
                FormField(
                    var="command",
                    label="Command",
                    type="list-single",
                    options=[
                        {
                            "label": strip_leading_emoji_if_needed(command.NAME),
                            "value": command.NODE,
                        }
                        for command in commands.values()
                    ],
                )
            ],
            partial(self.__handle_category_choice, commands),
        )
        return await self.__handle_result(session, form, adhoc_session)

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ):
        """
        Run the command picked in the category form.

        :param commands: the commands offered in the form, keyed by node
        :param form_values: the submitted form values
        :raises XMPPError: ``bad-request`` if the submitted value is not one of
            the offered commands
        """
        command = commands.get(form_values.get("command"))  # type: ignore[arg-type]
        if command is None:
            # The submitted value does not match any option that was offered.
            raise XMPPError("bad-request", "Unknown command")
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command's return value into the XEP-0050 session dict.

        - ``str`` or ``None``: complete the command with an info note
          ("Success!" when the value is empty or ``None``).
        - :class:`Form`: send the form; its handler runs when it is submitted.
          If it has a ``timeout_handler``, a ``FORM_TIMEOUT`` timer is armed.
        - :class:`Confirmation`: send a confirmation form; its handler runs if
          the user confirms.
        - :class:`TableResult`: complete the command with the table.

        :raises XMPPError: ``internal-server-error`` for any other type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                session_id = adhoc_session["id"]
                # Do not leave an older timer for this session running
                # unreferenced when it gets replaced.
                self.__cancel_timeout(session_id)
                self.__timeouts[session_id] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(self.__wrap_timeout, result.timeout_handler, session_id),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        log.error("Unexpected adhoc command result: %r", result)
        raise XMPPError("internal-server-error", text="OOPS!")

    def __cancel_timeout(self, session_id: str) -> None:
        """Cancel and forget the pending form timeout of ``session_id``, if any."""
        timer = self.__timeouts.pop(session_id, None)
        if timer is not None:
            log.debug("Cancelled form timeout of adhoc session %s", session_id)
            timer.cancel()

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Timer callback: drop the adhoc session and call the form's timeout
        handler.
        """
        # The timer has fired; forget it so __timeouts does not grow forever.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Callable | functools.partial, *a, **k):  # type: ignore
        """
        Call a sync or async handler and return its result.

        :class:`XMPPError` propagates unchanged. Any other exception is logged
        at debug level and re-raised as an ``internal-server-error``
        :class:`XMPPError` carrying the exception's text.
        """
        try:
            result = f(*a, **k)
            # Await whatever the handler returned if it is awaitable. This
            # covers coroutine functions, partials wrapping them, and callable
            # objects with an async ``__call__``.
            if inspect.isawaitable(result):
                result = await result
            return result
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        XEP-0050 ``next`` step for a submitted :class:`Form`.

        Cancels the form's timeout, extracts the submitted values, and passes
        them to the form's handler along with the session, the requester's JID
        and the form's extra handler arguments.
        """
        self.__cancel_timeout(adhoc_session["id"])
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        XEP-0050 ``next`` step for a submitted :class:`Confirmation` form.

        Runs the confirmation handler only if the ``confirm`` field is truthy.
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            # A configured success message takes precedence over whatever the
            # handler returned.
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        :param command: the command to expose. Without a category it gets its
            own node; otherwise it is listed under its category's node, which
            is created on first use.
        :param jid: the JID the command node is added on; defaults to the
            gateway's bound JID. Non-JID values are converted with ``JID()``.
        :raises RuntimeError: if an uncategorized command is already registered
            for the same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
            return

        if isinstance(category, str):
            category = CommandCategory(category, category)
        node = category.node
        if node not in self._categories:
            # First command of this category: create its node.
            self._categories[node] = list[Command]()
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=node,
                name=strip_leading_emoji_if_needed(category.name),
                handler=partial(self.__handle_category_list, category),
            )
        self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command, ifrom: JID, session: Optional["BaseSession[Any, Any]"]
    ) -> bool:
        """Return whether ``ifrom`` may run ``command``, using ``session`` as is."""
        try:
            command.raise_if_not_authorized(
                ifrom, fetch_session=False, session=session
            )
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        A command node is listed if the requester may run that command; a
        category node is listed if the requester may run at least one of its
        commands.

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester neither matches the
            user JID validator nor is listed in ``config.ADMINS``
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                candidates = self._categories[item_node]
            else:
                # NOTE(review): a node that was not registered through this
                # provider raises KeyError here — confirm that cannot happen.
                candidates = [self._commands[item_node]]
            if any(self.__is_authorized(c, ifrom, session) for c in candidates):
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Pass ``text`` through ``strip_leading_emoji`` when the
    ``STRIP_LEADING_EMOJI_ADHOC`` option is enabled; return it unchanged
    otherwise.

    Applied to adhoc command and category names before they are advertised.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text
# Module logger. It is defined after AdhocProvider, but the class only looks
# the name up inside method bodies, so it already exists when they run.
log = logging.getLogger(__name__)