Coverage for slidge / command / adhoc.py: 88%
160 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-06 15:18 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-06 15:18 +0000
1import asyncio
2import functools
3import inspect
4import logging
5from functools import partial
6from typing import TYPE_CHECKING, Any, Callable, Optional, Union
8from slixmpp import JID, Iq
9from slixmpp.exceptions import XMPPError
10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
13from ..core import config
14from ..util.util import strip_leading_emoji
15from . import Command, CommandResponseType, Confirmation, Form, TableResult
16from .base import FormField
17from .categories import CommandCategory
19if TYPE_CHECKING:
20 from ..core.gateway import BaseGateway
21 from ..core.session import BaseSession
# Session state dict for one adhoc command execution, as stored by slixmpp's
# xep_0050 plugin (presumably the same objects found in its ``sessions``).
# This module reads/writes the keys "id", "from", "next", "has_next",
# "payload" and "notes".
AdhocSessionType = dict[str, Any]
class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Uncategorized commands are registered as individual XEP-0050 nodes.
    Categorized commands are grouped under one node per category; executing
    that node first presents a form listing the commands of the category that
    the requester is authorized to run.
    """

    # Delay before a form that has a ``timeout_handler`` expires.
    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Uncategorized commands, keyed by node.
        self._commands = dict[str, Command]()
        # Categorized commands, keyed by category node.
        self._categories = dict[str, list[Command]]()
        # Take over disco#items on the XEP-0050 commands node so that the
        # listing can be filtered per requester (see get_items).
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # Pending form-expiry timers, keyed by XEP-0050 session id.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Run the first step of an uncategorized command.

        :param command: the command bound to this node at registration time
        :param iq: the IQ that started the adhoc session
        :param adhoc_session: the XEP-0050 session dict to fill in
        :raises XMPPError: if the requester is not authorized, or if the
            command fails
        :return: the updated session dict
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Run the first step of a category node: offer a form to pick a command.

        Only the commands of this category that the requester is authorized
        to run are offered.

        :param category: the category bound to this node at registration time
        :param iq: the IQ that started the adhoc session
        :param adhoc_session: the XEP-0050 session dict to fill in
        :raises XMPPError: ``not-authorized`` if no command of the category
            can be run by the requester
        :return: the updated session dict
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # No session for this requester; per-command authorization below
            # decides what, if anything, is offered.
            session = None
        ifrom = iq.get_from()
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(ifrom)
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        options = [
            {
                "label": strip_leading_emoji_if_needed(command.NAME),
                "value": command.NODE,
            }
            for command in commands.values()
        ]
        form = Form(
            category.name,
            "",
            [
                FormField(
                    var="command",
                    label="Command",
                    type="list-single",
                    options=options,
                )
            ],
            partial(self.__handle_category_choice, commands),
        )
        return await self.__handle_result(session, form, adhoc_session)

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ):
        """
        Form handler for the category list: run the command that was picked.

        :param commands: the commands that were offered, keyed by node
        :param form_values: the submitted form; ``"command"`` holds the node
        :param session: the requester's session, if any
        :param jid: the requester's JID
        :raises XMPPError: ``bad-request`` if the submitted choice is missing
            or was not offered
        :return: whatever the chosen command's ``run`` returned
        """
        try:
            command = commands[form_values["command"]]
        except KeyError:
            # Client-supplied value outside the offered options: this is a
            # client error, not an internal one.
            raise XMPPError("bad-request", "Invalid command choice") from None
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command step's return value into XEP-0050 session state.

        - ``str`` or ``None``: the command is finished; the string (or
          ``"Success!"`` when empty) is sent as an info note.
        - :class:`Form`: the form is sent and its handler becomes the next
          step. If the form has a ``timeout_handler``, that handler is
          scheduled to run after :attr:`FORM_TIMEOUT` seconds unless the form
          is answered first.
        - :class:`Confirmation`: a confirmation form is sent and the
          confirmation handler becomes the next step.
        - :class:`TableResult`: the table is sent and the command is finished.

        :raises XMPPError: ``internal-server-error`` for any other result type
        :return: the updated session dict
        """
        if result is None or isinstance(result, str):
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                session_id = adhoc_session["id"]
                # Never leave an older timer for this session running orphaned.
                self.__cancel_timeout(session_id)
                self.__timeouts[session_id] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(self.__wrap_timeout, result.timeout_handler, session_id),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __cancel_timeout(self, session_id: str) -> None:
        """Cancel and forget the pending form timer of a session, if any."""
        timer = self.__timeouts.pop(session_id, None)
        if timer is not None:
            log.debug("Cancelled form timeout of adhoc session %s", session_id)
            timer.cancel()

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Event-loop callback for a form that was not answered in time.

        Forgets the fired timer, drops the XEP-0050 session and calls the
        form's ``timeout_handler``.

        :param handler: the form's ``timeout_handler``
        :param session_id: the XEP-0050 session id the form belongs to
        """
        # The timer has fired: drop its handle so the dict does not grow.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call a command handler, awaiting it if it is a coroutine function.

        Accepts plain functions, coroutine functions, and ``partial`` objects
        wrapping either.

        :raises XMPPError: re-raised as-is if the handler raised one; any other
            exception is logged and converted to ``internal-server-error``
            (with the original exception chained as its cause)
        :return: the handler's (awaited) return value
        """
        try:
            if inspect.iscoroutinefunction(f):
                return await f(*a, **k)
            elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
                return await f(*a, **k)
            else:
                return f(*a, **k)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Form`: pass the submitted values to its handler.

        Any pending timeout for this session is cancelled first, since the
        form has now been answered.

        :param session: the requester's session, if any
        :param result: the form that was sent
        :param form: the submitted data form
        :param adhoc_session: the XEP-0050 session dict
        :return: the updated session dict
        """
        self.__cancel_timeout(adhoc_session["id"])
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Confirmation`.

        If the submitted form's ``confirm`` value is truthy, the confirmation
        handler is run. Its result is replaced by ``confirmation.success``
        when that is set. Otherwise the operation is reported as canceled.

        :return: the updated session dict
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care
        of that.

        Uncategorized commands get their own node. Categorized commands are
        appended to their category, whose node is registered the first time
        the category is seen.

        :param command: the command to expose
        :param jid: the JID the command is attached to; defaults to the
            gateway's bound JID
        :raises RuntimeError: if an uncategorized command is already
            registered for the same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command, ifrom: JID, session: Optional["BaseSession[Any, Any]"]
    ) -> bool:
        """Whether ``ifrom`` may run ``command``, using an already-fetched session."""
        try:
            command.raise_if_not_authorized(
                ifrom, fetch_session=False, session=session
            )
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :raises XMPPError: ``forbidden`` if the requester matches neither the
            user JID validator nor the configured admins
        :return: commands accessible to the given JID will be listed
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        # Fetch the session once and reuse it for every authorization check.
        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands may be
                # run; register() never creates an empty category.
                authorized = any(
                    self.__is_authorized(command, ifrom, session)
                    for command in self._categories[item_node]
                )
            else:
                authorized = self.__is_authorized(
                    self._commands[item_node], ifrom, session
                )
            if authorized:
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Pass ``text`` through :func:`strip_leading_emoji` when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is enabled; otherwise return
    it unchanged.

    Applied to adhoc command and category names before they are exposed.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text


# Module-level logger used throughout this module.
log = logging.getLogger(__name__)