Coverage for slidge / command / chat_command.py: 57%
185 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-03-13 22:59 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-03-13 22:59 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import functools
7import inspect
8import logging
9from collections.abc import Callable
10from typing import TYPE_CHECKING, Any, Literal, Never, overload
11from urllib.parse import quote as url_quote
13from slixmpp import JID, CoroutineCallback, Message, StanzaPath
14from slixmpp.exceptions import XMPPError
15from slixmpp.types import JidStr, MessageTypes
17from . import Command, CommandResponseType, Confirmation, Form, TableResult
18from .categories import CommandCategory
20if TYPE_CHECKING:
21 from ..core.gateway import BaseGateway
24class ChatCommandProvider:
25 UNKNOWN = "Wut? I don't know that command: {}"
27 def __init__(self, xmpp: "BaseGateway") -> None:
28 self.xmpp = xmpp
29 self._keywords = list[str]()
30 self._commands: dict[str, Command] = {}
31 self._input_futures = dict[str, asyncio.Future[str]]()
32 self.xmpp.register_handler(
33 CoroutineCallback(
34 "chat_command_handler",
35 StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
36 self._handle_message, # type: ignore
37 )
38 )
40 def register(self, command: Command) -> None:
41 """
42 Register a command to be used via chat messages with the gateway
44 Plugins should not call this, any class subclassing Command should be
45 automatically added by slidge core.
47 :param command: the new command
48 """
49 t = command.CHAT_COMMAND
50 if t in self._commands:
51 raise RuntimeError("There is already a command triggered by '%s'", t)
52 self._commands[t] = command
54 @overload
55 async def input(self, jid: JidStr, text: str | None = None) -> str: ...
57 @overload
58 async def input(
59 self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
60 ) -> asyncio.Future[str]: ...
62 @overload
63 async def input(
64 self,
65 jid: JidStr,
66 text: str | None = None,
67 *,
68 mtype: MessageTypes = "chat",
69 timeout: int = 60,
70 blocking: Literal[True] = True,
71 **msg_kwargs: Any,
72 ) -> str: ...
74 async def input(
75 self,
76 jid: JidStr,
77 text: str | None = None,
78 *,
79 mtype: MessageTypes = "chat",
80 timeout: int = 60,
81 blocking: bool = True,
82 **msg_kwargs: Any,
83 ) -> str | asyncio.Future[str]:
84 """
85 Request arbitrary user input using a simple chat message, and await the result.
87 You shouldn't need to call directly bust instead use :meth:`.BaseSession.input`
88 to directly target a user.
90 NB: When using this, the next message that the user sent to the component will
91 not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
92 Await the coroutine to get its content.
94 :param jid: The JID we want input from
95 :param text: A prompt to display for the user
96 :param mtype: Message type
97 :param timeout:
98 :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
99 is returned instead of a str
100 :return: The user's reply
101 """
102 jid = JID(jid)
103 if text is not None:
104 self.xmpp.send_message(
105 mto=jid,
106 mbody=text,
107 mtype=mtype,
108 mfrom=self.xmpp.boundjid.bare,
109 **msg_kwargs,
110 )
111 f = asyncio.get_event_loop().create_future()
112 self._input_futures[jid.bare] = f
113 if not blocking:
114 return f
115 try:
116 await asyncio.wait_for(f, timeout)
117 except TimeoutError:
118 self.xmpp.send_message(
119 mto=jid,
120 mbody="You took too much time to reply",
121 mtype=mtype,
122 mfrom=self.xmpp.boundjid.bare,
123 )
124 del self._input_futures[jid.bare]
125 raise XMPPError("remote-server-timeout", "You took too much time to reply")
127 return f.result()
129 async def _handle_message(self, msg: Message) -> None:
130 if not msg["body"]:
131 return
133 if not msg.get_from().node:
134 return # ignore component and server messages
136 f = self._input_futures.pop(msg.get_from().bare, None)
137 if f is not None:
138 f.set_result(msg["body"])
139 return
141 c = msg["body"]
142 first_word, *rest = c.split(" ")
143 first_word = first_word.lower()
145 if first_word == "help":
146 return self._handle_help(msg, *rest)
148 mfrom = msg.get_from()
150 command = self._commands.get(first_word)
151 if command is None:
152 self._not_found(msg, first_word)
153 return
155 try:
156 session = command.raise_if_not_authorized(mfrom)
157 except XMPPError as e:
158 reply = msg.reply()
159 reply["body"] = e.text
160 reply.send()
161 raise
163 result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
164 self.xmpp.delivery_receipt.ack(msg)
165 return await self._handle_result(result, msg, session)
167 def __make_uri(self, body: str) -> str:
168 return f"xmpp:{self.xmpp.boundjid.bare}?message;body={body}"
170 async def _handle_result(self, result: CommandResponseType, msg: Message, session):
171 if isinstance(result, str) or result is None:
172 reply = msg.reply()
173 reply["body"] = result or "End of command."
174 reply.send()
175 return
177 if isinstance(result, Form):
178 try:
179 return await self.__handle_form(result, msg, session)
180 except XMPPError as e:
181 if (
182 result.timeout_handler is None
183 or e.condition != "remote-server-timeout"
184 ):
185 raise e
186 return result.timeout_handler()
188 if isinstance(result, Confirmation):
189 yes_or_no = await self.input(msg.get_from(), result.prompt)
190 if not yes_or_no.lower().startswith("y"):
191 reply = msg.reply()
192 reply["body"] = "Canceled"
193 reply.send()
194 return
195 result = await self.__wrap_handler(
196 msg,
197 result.handler,
198 session,
199 msg.get_from(),
200 *result.handler_args,
201 **result.handler_kwargs,
202 )
203 return await self._handle_result(result, msg, session)
205 if isinstance(result, TableResult):
206 if len(result.items) == 0:
207 msg.reply("Empty results").send()
208 return
210 body = result.description + "\n"
211 for item in result.items:
212 for f in result.fields:
213 if f.type == "jid-single":
214 j = JID(item[f.var])
215 value = f"xmpp:{percent_encode(j)}"
216 if result.jids_are_mucs:
217 value += "?join"
218 else:
219 value = item[f.var] # type:ignore
220 body += f"\n{f.label or f.var}: {value}"
221 msg.reply(body).send()
223 async def __handle_form(self, result: Form, msg: Message, session):
224 form_values = {}
225 for t in result.title, result.instructions:
226 if t:
227 msg.reply(t).send()
228 for f in result.fields:
229 if f.type == "fixed":
230 msg.reply(f"{f.label or f.var}: {f.value}").send()
231 else:
232 if f.type == "list-multi":
233 msg.reply(
234 "Multiple selection allowed, use new lines as a separator, ie, "
235 "one selected item per line. To select no item, reply with a space "
236 "(the punctuation)."
237 ).send()
238 if f.options:
239 for o in f.options:
240 msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
241 if f.value:
242 msg.reply(f"Default: {f.value}").send()
243 if f.type == "boolean":
244 msg.reply("yes: " + self.__make_uri("yes")).send()
245 msg.reply("no: " + self.__make_uri("no")).send()
247 ans = await self.xmpp.input(
248 msg.get_from(), (f.label or f.var) + "? (or 'abort')"
249 )
250 if ans.lower() == "abort":
251 return await self._handle_result("Command aborted", msg, session)
252 if f.type == "boolean":
253 if ans.lower() == "yes":
254 ans = "true"
255 else:
256 ans = "false"
258 if f.type.endswith("multi"):
259 choices = [] if ans == " " else ans.split("\n")
260 form_values[f.var] = f.validate(choices)
261 else:
262 form_values[f.var] = f.validate(ans)
263 result = await self.__wrap_handler(
264 msg,
265 result.handler,
266 form_values,
267 session,
268 msg.get_from(),
269 *result.handler_args,
270 **result.handler_kwargs,
271 )
272 return await self._handle_result(result, msg, session)
274 @staticmethod
275 async def __wrap_handler(msg, f: Callable | functools.partial, *a, **k):
276 try:
277 if inspect.iscoroutinefunction(f):
278 return await f(*a, **k)
279 elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
280 return await f(*a, **k)
281 else:
282 return f(*a, **k)
283 except Exception as e:
284 log.debug("Error in %s", f, exc_info=e)
285 reply = msg.reply()
286 reply["body"] = f"Error: {e}"
287 reply.send()
289 def _handle_help(self, msg: Message, *rest: str) -> None:
290 if len(rest) == 0:
291 reply = msg.reply()
292 reply["body"] = self._help(msg.get_from())
293 reply.send()
294 elif len(rest) == 1 and (command := self._commands.get(rest[0])):
295 reply = msg.reply()
296 reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
297 reply.send()
298 else:
299 self._not_found(msg, str(rest))
301 def _help(self, mfrom: JID) -> str:
302 session = self.xmpp.get_session_from_jid(mfrom)
304 msg = "Available commands:"
305 for c in sorted(
306 self._commands.values(),
307 key=lambda co: (
308 (
309 co.CATEGORY
310 if isinstance(co.CATEGORY, str)
311 else (
312 co.CATEGORY.name
313 if isinstance(co.CATEGORY, CommandCategory)
314 else ""
315 )
316 ),
317 co.CHAT_COMMAND,
318 ),
319 ):
320 try:
321 c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
322 except XMPPError:
323 continue
324 msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
325 return msg
327 def _not_found(self, msg: Message, word: str) -> Never:
328 e = self.UNKNOWN.format(word)
329 msg.reply(e).send()
330 raise XMPPError("item-not-found", e)
def percent_encode(jid: JID) -> str:
    """
    Render a JID as ``local@domain`` with the local part percent-encoded.

    Only ``jid.user`` is quoted; ``jid.server`` is inserted unchanged.

    :param jid: the JID to encode
    :return: the encoded string, without any ``xmpp:`` scheme prefix
    """
    local_part = url_quote(jid.user)
    return "{}@{}".format(local_part, jid.server)
# Module-level logger, named after this module. It is defined at the bottom of
# the file but only looked up at call time (see ChatCommandProvider.__wrap_handler).
log = logging.getLogger(__name__)