Coverage for slidge/command/adhoc.py: 92%
139 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import asyncio
2import functools
3import logging
4from functools import partial
5from typing import TYPE_CHECKING, Any, Callable, Optional, Union
7from slixmpp import JID, Iq # type: ignore[attr-defined]
8from slixmpp.exceptions import XMPPError
9from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined]
10from slixmpp.plugins.xep_0030.stanza.items import DiscoItems
12from ..core import config
13from ..util.util import strip_leading_emoji
14from . import Command, CommandResponseType, Confirmation, Form, TableResult
15from .base import FormField
16from .categories import CommandCategory
18if TYPE_CHECKING:
19 from ..core.gateway import BaseGateway
20 from ..core.session import BaseSession
# Per-command-session state dict exchanged with slixmpp's XEP-0050 plugin.
# Keys read/written in this module: "next", "has_next", "payload", "notes", "from".
AdhocSessionType = dict[str, Any]


class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands (XEP-0050), with less
    boilerplate and fewer untyped dict values than slixmpp.

    Commands are added with :meth:`register`. The :meth:`get_items` disco
    handler then lists only the commands (or command categories) that the
    requesting JID is authorized to run.
    """

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Uncategorised commands, keyed by their NODE.
        self._commands = dict[str, Command]()
        # Categorised commands, grouped under their category's node.
        self._categories = dict[str, list[Command]]()
        # Route disco#items on the commands node through get_items so that the
        # advertised command list can be filtered per requester.
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )

    @staticmethod
    def __is_authorized(command: Command, jid: JID) -> bool:
        """
        Tell whether ``jid`` passes ``command``'s authorization check.

        :param command: the command to check
        :param jid: the requesting JID
        :return: ``False`` if ``command.raise_if_not_authorized`` raises an
            :class:`XMPPError`, ``True`` otherwise
        """
        try:
            command.raise_if_not_authorized(jid)
        except XMPPError:
            return False
        return True

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of an uncategorised command: check authorization, run the
        command and store its result in the adhoc session.

        :param command: the command bound at registration time
        :param iq: the incoming command IQ
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session
        :raises XMPPError: as raised by ``raise_if_not_authorized`` or by the
            command itself
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a command category: present a form listing the commands
        of the category that the requester is authorized to run.

        :param category: the category bound at registration time
        :param iq: the incoming command IQ
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session
        :raises XMPPError: ``not-authorized`` if no command of the category is
            available to the requester
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # No session could be found for the requester; carry on without one.
            session = None
        ifrom = iq.get_from()
        commands: dict[str, Command] = {
            command.NODE: command
            for command in self._categories[category.node]
            if self.__is_authorized(command, ifrom)
        }
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: Optional["BaseSession[Any, Any]"],
        jid: JID,
    ) -> CommandResponseType:
        """
        Form handler of the category list: run the command the user picked.

        :param commands: the commands that were offered, keyed by NODE
            (bound with ``partial`` in ``__handle_category_list``)
        :param form_values: submitted values; ``"command"`` holds the chosen NODE
        :param session: the requester's session, or ``None`` if none was found
        :param jid: the requester's JID
        :return: whatever the chosen command's ``run`` returns
        """
        command = commands[form_values["command"]]
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command handler's return value into the adhoc session.

        - ``str`` or ``None``: final step, shown as an info note
          (``"Success!"`` when the value is ``None`` or empty).
        - :class:`Form`: send the form; its handler runs on submission.
        - :class:`Confirmation`: send a confirmation form; its handler runs if
          the user confirms.
        - :class:`TableResult`: final step, send the table.

        :return: the updated adhoc session
        :raises XMPPError: ``internal-server-error`` for any other result type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call a command handler, awaiting its result if it returned a coroutine.

        Sync functions, ``async def`` functions and ``partial`` objects wrapping
        either are all supported.

        :param f: the handler to call
        :return: the handler's (awaited) return value
        :raises XMPPError: any ``XMPPError`` raised by the handler is propagated
            unchanged; any other exception becomes ``internal-server-error``
            carrying the exception's text
        """
        try:
            result = f(*a, **k)
            if asyncio.iscoroutine(result):
                result = await result
            return result
        except XMPPError:
            # Handlers raise XMPPError on purpose to pick a specific error
            # condition; re-wrapping would hide it behind internal-server-error.
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the submission of a :class:`Form` previously sent to the user.

        The form handler receives the parsed values, the session, the
        requester's JID and the form's extra ``handler_args``/``handler_kwargs``.

        :return: the updated adhoc session
        """
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a :class:`Confirmation`.

        If the ``confirm`` field is truthy, the confirmation handler runs. Its
        result is replaced by ``confirmation.success`` when that is set.
        Otherwise the operation is reported as canceled.

        :return: the updated adhoc session
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        Commands without a ``CATEGORY`` are exposed directly under their own
        ``NODE``. Commands with a ``CATEGORY`` are grouped instead: the category
        node is registered the first time it is seen. It presents a list of its
        commands that the requester is authorized to run.

        :param command: the command to expose
        :param jid: the JID the command is registered on; defaults to the
            gateway's bound JID, and non-``JID`` values are converted with
            ``JID()``
        :raises RuntimeError: if an uncategorised command with the same
            ``NODE`` is already registered
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                # First command of this category: register the category node.
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query on the commands node.

        :param jid: passed as-is to the static disco items lookup
        :param node: which command node is requested
        :param iq: the disco query IQ; its sender is the requester whose
            authorizations are checked
        :return: the commands (and categories with at least one command) that
            the requester is authorized to run
        """
        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        ifrom = iq.get_from()

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is visible if any of its commands is runnable.
                authorized = any(
                    self.__is_authorized(command, ifrom)
                    for command in self._categories[item_node]
                )
            else:
                # NOTE(review): a node registered on this JID by something other
                # than this provider raises KeyError here — confirm none exist.
                authorized = self.__is_authorized(self._commands[item_node], ifrom)
            if authorized:
                filtered_items.append(item)

        return filtered_items
def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return ``text`` without its leading emoji, but only when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is enabled.

    :param text: a command or category display name
    :return: the stripped text if the option is truthy, ``text`` unchanged
        otherwise
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text


# Module-level logger, used by AdhocProvider's methods.
log = logging.getLogger(__name__)