Coverage for slidge / command / chat_command.py: 57%
184 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-02-15 09:02 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import functools
7import inspect
8import logging
9from collections.abc import Callable
10from typing import TYPE_CHECKING, Any, Literal, overload
11from urllib.parse import quote as url_quote
13from slixmpp import JID, CoroutineCallback, Message, StanzaPath
14from slixmpp.exceptions import XMPPError
15from slixmpp.types import JidStr, MessageTypes
17from . import Command, CommandResponseType, Confirmation, Form, TableResult
18from .categories import CommandCategory
20if TYPE_CHECKING:
21 from ..core.gateway import BaseGateway
class ChatCommandProvider:
    """
    Expose slidge commands through plain chat messages sent to the gateway JID.

    Messages addressed to the gateway's bare JID are intercepted. If an
    :meth:`input` call is waiting on the sender, the message body is delivered
    to it. Otherwise the first word of the body (case-insensitive) selects a
    registered command, and the remaining words are passed to it as arguments.
    """

    # Reply template for an unrecognised command word; ``{}`` is the word.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # Registered commands, keyed by their CHAT_COMMAND trigger word.
        self._commands: dict[str, Command] = {}
        # Pending input() requests, keyed by the bare JID a reply is expected from.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this: any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger
        """
        trigger = command.CHAT_COMMAND
        if trigger in self._commands:
            # Exceptions do not apply logging-style "%s" interpolation to their
            # arguments, so build the final message here.
            raise RuntimeError(f"There is already a command triggered by '{trigger}'")
        self._commands[trigger] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component
        will not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but
        rather intercepted. Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for a reply (only when ``blocking``)
        :param blocking: If set to False, timeout has no effect and an
            :class:`asyncio.Future` is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f = asyncio.get_running_loop().create_future()
        # NOTE(review): an older pending future for the same JID is replaced
        # here and will no longer be resolved by an incoming chat message.
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            return await asyncio.wait_for(f, timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the builtin TimeoutError on 3.11+, but a
            # distinct class on 3.10, so catch the asyncio name.
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # On timeout or cancellation, drop our entry so the next message is
            # handled normally. Never remove a future registered by a later
            # input() call for the same JID.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message):
        """
        Route a chat message sent to the gateway JID.

        Bodiless messages and messages from JIDs without a localpart are ignored.
        A pending :meth:`input` waiter for the sender receives the raw body.
        Otherwise the body is parsed as ``<command> [args...]``.
        """
        body = msg["body"]
        if not body:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        waiter = self._input_futures.pop(mfrom.bare, None)
        # A waiter that was cancelled (or otherwise completed) can no longer
        # accept a result; treat the message as a regular command instead.
        if waiter is not None and not waiter.done():
            waiter.set_result(body)
            return

        # split() collapses runs of whitespace, so "cmd  a  b" gives ["a", "b"]
        # rather than empty-string arguments. A whitespace-only body maps to "".
        first_word, *rest = body.split() or [""]
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when clicked, sends *body* to the gateway.
        """
        # The body goes into the URI query, so characters such as space, '&',
        # ';' or '#' must be percent-encoded or the link breaks.
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command's result to the user as chat messages.

        Strings (or None) are sent as a single reply. Forms and confirmations are
        handled interactively; their handlers' results are rendered recursively.
        Tables are rendered as one text message.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            try:
                return await self.__handle_form(result, msg, session)
            except XMPPError as e:
                # Only a user-input timeout is delegated to the form's own
                # timeout handler; anything else propagates.
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise e
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return

            lines = [result.description + "\n"]
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        j = JID(item[field.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    lines.append(f"{field.label or field.var}: {value}")
            msg.reply("\n".join(lines)).send()

    async def __handle_form(self, result: Form, msg: Message, session):
        """
        Fill a :class:`Form` by asking the user for each field over chat.

        The title, instructions and fixed fields are echoed. Each other field is
        prompted for in turn, with its options, default and boolean shortcuts
        offered as clickable ``xmpp:`` URIs. Replying ``abort`` stops the form.
        The collected, validated values are then passed to the form handler.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for field in result.fields:
            if field.type == "fixed":
                msg.reply(f"{field.label or field.var}: {field.value}").send()
                continue

            if field.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if field.options:
                for o in field.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if field.value:
                msg.reply(f"Default: {field.value}").send()
            if field.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(), (field.label or field.var) + "? (or 'abort')"
            )
            if ans.lower() == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if field.type == "boolean":
                ans = "true" if ans.lower() == "yes" else "false"

            if field.type.endswith("multi"):
                # A single space is the documented way to select nothing.
                choices = [] if ans == " " else ans.split("\n")
                form_values[field.var] = field.validate(choices)
            else:
                form_values[field.var] = field.validate(ans)
        result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(result, msg, session)

    @staticmethod
    async def __wrap_handler(msg, f: Callable | functools.partial, *a, **k):
        """
        Call a command handler and await its result if it is awaitable.

        This covers ``async def`` functions, partials of them, and sync callables
        that return a coroutine (e.g. objects with an ``async __call__``).

        Any exception is logged at debug level and reported to the user as
        ``Error: <exception>``; ``None`` is then returned, which the caller renders
        as "End of command.".
        """
        try:
            result = f(*a, **k)
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest) -> None:
        """
        Reply to ``help`` (list commands) or ``help <command>`` (describe one).

        :raises XMPPError: ``item-not-found`` for an unknown command
        """
        if not rest:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            # Lowercase like command dispatch in _handle_message does.
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            # Echo what the user typed, not a Python tuple repr.
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID):
        """
        Build the "Available commands" listing for *mfrom*.

        Commands are sorted by category name, then by trigger word. Commands the
        sender is not authorised to run are omitted.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        def sort_key(co):
            category = co.CATEGORY
            if isinstance(category, str):
                name = category
            elif isinstance(category, CommandCategory):
                name = category.name
            else:
                name = ""
            return name, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str):
        """
        Tell the user *word* is not a known command, then raise.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)
def percent_encode(jid: JID) -> str:
    """
    Format *jid* as ``user@server`` with its localpart percent-encoded.

    Used to embed a JID inside an ``xmpp:`` URI.
    """
    quoted_user = url_quote(jid.user)  # type:ignore
    domain = jid.server
    return f"{quoted_user}@{domain}"


# Module-wide logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)