Coverage for slidge/command/chat_command.py: 57%
183 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-26 19:34 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-26 19:34 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import functools
7import inspect
8import logging
9from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, overload
10from urllib.parse import quote as url_quote
12from slixmpp import JID, CoroutineCallback, Message, StanzaPath
13from slixmpp.exceptions import XMPPError
14from slixmpp.types import JidStr, MessageTypes
16from . import Command, CommandResponseType, Confirmation, Form, TableResult
17from .categories import CommandCategory
19if TYPE_CHECKING:
20 from ..core.gateway import BaseGateway
class ChatCommandProvider:
    """
    Dispatch slidge commands received as plain chat messages to the gateway.

    A handler is registered for every message addressed to the gateway's bare
    JID. The lower-cased first word of the body selects a registered
    :class:`Command`; the remaining space-separated words are passed to it as
    positional arguments. The command result (text, :class:`Form`,
    :class:`Confirmation` or :class:`TableResult`) is rendered back to the user
    as one or more chat messages.

    This class also provides :meth:`input`. While an input request is pending
    for a bare JID, the next non-empty message body from that JID resolves the
    request instead of being interpreted as a command.
    """

    # Reply template when the first word of a message is not a known command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # trigger word (Command.CHAT_COMMAND) -> command instance
        self._commands: dict[str, Command] = {}
        # bare JID -> future resolved with the next message body from that JID
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this; any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            # Format the message ourselves: exception constructors do not
            # apply %-style interpolation to extra arguments.
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: Optional[str] = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: Optional[str] = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sent to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from (only its bare part is used to
            match the reply)
        :param text: A prompt to display for the user; nothing is sent if None
        :param mtype: Message type of the prompt (and of the timeout notice)
        :param timeout: Seconds to wait for a reply when ``blocking`` is True
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments for ``send_message`` when
            sending the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrived in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f = asyncio.get_running_loop().create_future()
        # A newer request for the same bare JID supersedes an older one.
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            return await asyncio.wait_for(f, timeout)
        except asyncio.TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # Remove our pending entry on timeout and on cancellation alike.
            # The entry is already gone when a reply arrived. Never remove a
            # newer request registered for the same JID in the meantime.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message):
        """
        Handle a chat message sent to the gateway's bare JID.

        If a pending :meth:`input` request exists for the sender, the body
        resolves it. Otherwise the first word selects the command to run.

        :raises XMPPError: if the sender is not authorized for the command, or
            if the command is unknown (see :meth:`_not_found`)
        """
        if not msg["body"]:
            return

        if not msg.get_from().node:
            return  # ignore component and server messages

        f = self._input_futures.pop(msg.get_from().bare, None)
        # A done (e.g. cancelled) future can no longer take a result. In that
        # case handle the message normally instead of raising InvalidStateError.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        first_word, *rest = msg["body"].split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        mfrom = msg.get_from()

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            # Tell the user why, then let the error propagate.
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when clicked, sends *body* to the gateway.

        The body is percent-encoded as required for XMPP URI query values.
        Otherwise spaces, ``;``, ``&`` or ``#`` in option values break the link.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command result back to the user as chat messages.

        - ``str`` / ``None``: a single reply; an empty string or ``None``
          becomes ``"End of command."``.
        - :class:`Form`: fields are asked one by one (see :meth:`__handle_form`).
          A ``remote-server-timeout`` error calls ``result.timeout_handler``
          when one is set.
        - :class:`Confirmation`: asks the prompt and runs the handler only if
          the answer starts with "y".
        - :class:`TableResult`: one message listing every item's fields.

        Other result types are ignored.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            try:
                return await self.__handle_form(result, msg, session)
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise e
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return

            body = result.description + "\n"
            for item in result.items:
                for f in result.fields:
                    if f.type == "jid-single":
                        # Render JIDs as clickable xmpp: URIs.
                        j = JID(item[f.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[f.var]  # type:ignore
                    body += f"\n{f.label or f.var}: {value}"
            msg.reply(body).send()

    async def __handle_form(self, result: Form, msg: Message, session):
        """
        Fill *result* interactively, one chat question per non-fixed field.

        Replying "abort" to any question ends the command. Boolean fields
        accept yes/y/true/1 (case-insensitive) as true and anything else as
        false. For ``*-multi`` fields, the answer is split on new lines, and a
        single space means "no selection". The validated values are then
        passed to the form handler.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
            else:
                if f.type == "list-multi":
                    msg.reply(
                        "Multiple selection allowed, use new lines as a separator, ie, "
                        "one selected item per line. To select no item, reply with a space "
                        "(the punctuation)."
                    ).send()
                if f.options:
                    for o in f.options:
                        msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
                if f.value:
                    msg.reply(f"Default: {f.value}").send()
                if f.type == "boolean":
                    msg.reply("yes: " + self.__make_uri("yes")).send()
                    msg.reply("no: " + self.__make_uri("no")).send()

                ans = await self.xmpp.input(
                    msg.get_from(), (f.label or f.var) + "? (or 'abort')"
                )
                if ans.lower() == "abort":
                    return await self._handle_result("Command aborted", msg, session)
                if f.type == "boolean":
                    # Accept common affirmative spellings, not only the exact "yes".
                    if ans.strip().lower() in ("yes", "y", "true", "1"):
                        ans = "true"
                    else:
                        ans = "false"

                if f.type.endswith("multi"):
                    choices = [] if ans == " " else ans.split("\n")
                    form_values[f.var] = f.validate(choices)
                else:
                    form_values[f.var] = f.validate(ans)
        result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(result, msg, session)

    @staticmethod
    async def __wrap_handler(msg, f: Union[Callable, functools.partial], *a, **k):
        """
        Call a command handler, sync or async, and return its result.

        Any awaitable returned by *f* is awaited, which covers coroutine
        functions, partials of them and other async callables. On any
        exception the error is logged at debug level and reported to the user
        as ``"Error: <exception>"`` in a reply to *msg*. ``None`` is then
        returned.
        """
        try:
            result = f(*a, **k)
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest) -> None:
        """
        Answer ``help`` (list of commands) or ``help <command>`` (details).

        :raises XMPPError: ``item-not-found`` for an unknown topic
        """
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Lower-case like the main dispatch in _handle_message does.
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            # Show what the user typed, not a Python tuple repr.
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID):
        """
        Build the "Available commands" text for *mfrom*.

        Commands are sorted by category name, then trigger word. Commands the
        sender is not authorized to use are omitted.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str):
        """
        Tell the user *word* is not a known command.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)
def percent_encode(jid: "JID") -> str:
    """
    Return the ``local@domain`` part of *jid* for use in an ``xmpp:`` URI.

    The local part is percent-encoded with :func:`urllib.parse.quote`, since it
    may contain characters that are not valid in a URI, such as ``#`` or
    non-ASCII text. The domain is used verbatim and any resource is dropped.
    For a domain-only JID (empty local part), only the domain is returned,
    rather than the invalid ``@domain``.
    """
    if not jid.user:
        return jid.server  # type:ignore
    return f"{url_quote(jid.user)}@{jid.server}"  # type:ignore
# Module-wide logger, named after this module; used by the command handlers above.
log: logging.Logger = logging.getLogger(name=__name__)