Coverage for slidge/command/adhoc.py: 88%
158 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-26 19:34 +0000
1import asyncio
2import functools
3import inspect
4import logging
5from functools import partial
6from typing import TYPE_CHECKING, Any, Callable, Optional, Union
8from slixmpp import JID, Iq
9from slixmpp.exceptions import XMPPError
10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
13from ..core import config
14from ..util.util import strip_leading_emoji
15from . import Command, CommandResponseType, Confirmation, Form, TableResult
16from .base import FormField
17from .categories import CommandCategory
19if TYPE_CHECKING:
20 from ..core.gateway import BaseGateway
21 from ..core.session import BaseSession
# The per-request adhoc session state, as a plain dict. This module reads its
# "id" and "from" keys and writes "next", "has_next", "payload" and "notes"
# before handing it back to the xep_0050 plugin.
AdhocSessionType = dict[str, Any]
class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a ``CATEGORY`` are registered as their own XEP-0050 node.
    Commands sharing a category are grouped behind a single node per
    category, which first presents a ``list-single`` form listing the
    commands of that category the requester is authorized to run.
    """

    # Seconds after which a pending form that defines a ``timeout_handler``
    # is abandoned: its adhoc session is dropped and the handler is called.
    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # uncategorized commands, keyed by their own node
        self._commands = dict[str, Command]()
        # categorized commands, keyed by their category's node
        self._categories = dict[str, list[Command]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # armed form timeouts, keyed by adhoc session ID
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point of an uncategorized command node.

        Checks the requester may run ``command``, runs it and stores its
        result in ``adhoc_session``.

        :raises XMPPError: if the authorization check fails or the command
            raises.
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point of a category node: answer with a form listing the
        commands of ``category`` the requester is authorized to run.

        :raises XMPPError: ``not-authorized`` if no command of the category
            passes the authorization check.
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # no session for this requester: continue without one and let
            # each command's authorization check decide below
            session = None
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseType:
        """
        Form handler of the category list: run the command the user picked.

        :param commands: the commands that were offered, keyed by node (only
            those that passed the authorization check)
        :param form_values: submitted values; ``"command"`` holds the node
        :raises XMPPError: ``bad-request`` if the submitted node was not
            offered.
        """
        try:
            command = commands[form_values["command"]]
        except KeyError:
            # a node that was not offered (unknown, or one the requester is
            # not authorized to run) must not be executed
            raise XMPPError("bad-request", "Unknown command") from None
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Store a command handler's return value in ``adhoc_session``.

        - ``str`` or ``None``: end the session with an info note
          (``"Success!"`` when the result is falsy).
        - ``Form``: send the form; its submission goes to
          ``__wrap_form_handler``. If the form has a ``timeout_handler``, a
          ``FORM_TIMEOUT`` timer is armed for this session.
        - ``Confirmation``: send the confirmation form; the answer goes to
          ``__wrap_confirmation``.
        - ``TableResult``: end the session with the table as payload.

        :raises XMPPError: ``internal-server-error`` for any other type.
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Called by the event loop when a form was not answered in time: drop
        the xep_0050 session, then call the form's ``timeout_handler``.
        """
        # the timer has fired; forget it so ``__timeouts`` does not grow
        # with every unanswered form
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call ``f(*a, **k)``, awaiting the result if it is awaitable, so sync
        and async handlers (and partials of either) are supported alike.

        ``XMPPError`` raised by the handler propagates unchanged so its
        condition and text reach the requester; any other exception becomes
        an ``internal-server-error``.
        """
        try:
            result = f(*a, **k)
            if inspect.isawaitable(result):
                result = await result
            return result
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the submission of a ``Form`` previously sent by
        ``__handle_result``.

        Disarms the form's timeout (if any), calls the form handler with the
        parsed values, the session, the requester JID and the form's extra
        args, then stores whatever the handler returns.
        """
        timer = self.__timeouts.pop(adhoc_session["id"], None)
        if timer is not None:
            log.debug("Canceled form timeout for adhoc session %s", adhoc_session["id"])
            timer.cancel()
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a ``Confirmation``.

        If the ``confirm`` field is truthy, run the confirmation handler; its
        result is replaced by ``confirmation.success`` when that is set.
        Otherwise end with a cancellation note.
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as a adhoc command.

        this does not need to be called manually, ``BaseGateway`` takes care of
        that.

        :param command: the command to expose. If it has no ``CATEGORY`` it
            gets its own node; otherwise it is added to its category's node
            (created on first use; a ``str`` category is used as both node
            and name).
        :param jid: the JID the node is attached to; defaults to the
            gateway's bound JID.
        :raises RuntimeError: if an uncategorized command with the same node
            is already registered.
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command,
        ifrom: JID,
        session: Optional["BaseSession[Any, Any]"],
    ) -> bool:
        """Return whether ``command``'s authorization check passes for ``ifrom``."""
        try:
            command.raise_if_not_authorized(
                ifrom, fetch_session=False, session=session
            )
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester matches neither the
            user JID validator nor the admin list.
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # a category is listed if at least one of its commands is
                # runnable by the requester
                candidates = self._categories[item_node]
            elif item_node in self._commands:
                candidates = [self._commands[item_node]]
            else:
                # not registered through this provider, so we cannot check
                # authorization: hide it rather than fail the whole query
                log.debug("Hiding adhoc node not managed here: %s", item_node)
                continue
            if any(
                self.__is_authorized(command, ifrom, session) for command in candidates
            ):
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return ``text`` with its leading emoji removed when the
    ``STRIP_LEADING_EMOJI_ADHOC`` option is enabled, or unchanged otherwise.
    """
    if not config.STRIP_LEADING_EMOJI_ADHOC:
        return text
    return strip_leading_emoji(text)
# Module logger. The methods above reference it as a global at call time, so
# defining it after the class is fine.
log = logging.getLogger(name=__name__)