Coverage for slidge/command/chat_command.py: 58%
175 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-04 08:17 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import functools
7import logging
8from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, overload
9from urllib.parse import quote as url_quote
11from slixmpp import JID, CoroutineCallback, Message, StanzaPath
12from slixmpp.exceptions import XMPPError
13from slixmpp.types import JidStr, MessageTypes
15from . import Command, CommandResponseType, Confirmation, Form, TableResult
16from .categories import CommandCategory
18if TYPE_CHECKING:
19 from ..core.gateway import BaseGateway
class ChatCommandProvider:
    """
    Dispatch slidge commands received as chat messages sent to the gateway's bare JID.

    The first space-separated word of a message body selects the command
    (case-insensitively); the remaining words are passed to the command as
    positional arguments. A message coming from a JID that has a pending
    :meth:`input` request is intercepted and used to resolve that request instead.
    """

    # Reply template for unknown commands; ``{}`` receives the unknown word.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # CHAT_COMMAND trigger word -> command instance
        self._commands: dict[str, Command] = {}
        # bare JID -> future resolved with the next message body from that JID
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            # RuntimeError does not interpolate %-style arguments, so build the
            # message explicitly.
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: Optional[str] = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: Optional[str] = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly but instead use :meth:`.BaseSession.input`
        to directly target a user.

        NB: When using this, the next message that the user sends to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user (nothing is sent if None)
        :param mtype: Message type
        :param timeout: Seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            when sending the prompt
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        :return: The user's reply
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except asyncio.TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # Only forget our own future: a later input() call for the same JID
            # may have replaced it (and is still waiting), or a reply may
            # already have popped it.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message):
        """
        Handle a message addressed to the gateway's bare JID.

        Messages without a body, or from a JID without a local part, are ignored.
        A pending :meth:`input` request from the sender is resolved with the body.
        Otherwise the first word selects ``help`` or a registered command.
        """
        if not msg["body"]:
            return

        if not msg.get_from().node:
            return  # ignore component and server messages

        f = self._input_futures.pop(msg.get_from().bare, None)
        # A non-blocking input() future may have been cancelled by its owner;
        # set_result() on it would raise InvalidStateError, so in that case
        # handle the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        mfrom = msg.get_from()

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when followed, sends ``body`` to the gateway.

        ``body`` is percent-encoded so that values containing spaces, ``&``,
        ``;`` or ``=`` still yield a well-formed URI.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command result as chat replies, recursing on follow-up results.

        - ``str`` / ``None``: sent as a reply (``None`` becomes "End of command.")
        - :class:`Form`: each non-fixed field is asked for in turn via
          ``self.xmpp.input``; the collected values are passed to the form handler
        - :class:`Confirmation`: asks for a reply; anything not starting with
          "y" (case-insensitive) cancels
        - :class:`TableResult`: rendered as a single text message
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            form_values = {}
            for t in result.title, result.instructions:
                if t:
                    msg.reply(t).send()
            for f in result.fields:
                if f.type == "fixed":
                    msg.reply(f"{f.label or f.var}: {f.value}").send()
                else:
                    if f.type == "list-multi":
                        msg.reply(
                            "Multiple selection allowed, use new lines as a separator, ie, "
                            "one selected item per line. To select no item, reply with a space "
                            "(the punctuation)."
                        ).send()
                    if f.options:
                        for o in f.options:
                            msg.reply(
                                f"{o['label']}: {self.__make_uri(o['value'])}"
                            ).send()
                    if f.value:
                        msg.reply(f"Default: {f.value}").send()
                    if f.type == "boolean":
                        msg.reply("yes: " + self.__make_uri("yes")).send()
                        msg.reply("no: " + self.__make_uri("no")).send()

                    ans = await self.xmpp.input(
                        msg.get_from(), (f.label or f.var) + "? (or 'abort')"
                    )
                    if ans.lower() == "abort":
                        return await self._handle_result(
                            "Command aborted", msg, session
                        )
                    if f.type == "boolean":
                        # Anything other than "yes" counts as a negative answer.
                        if ans.lower() == "yes":
                            ans = "true"
                        else:
                            ans = "false"

                    if f.type.endswith("multi"):
                        # A lone space means "no item selected".
                        choices = [] if ans == " " else ans.split("\n")
                        form_values[f.var] = f.validate(choices)
                    else:
                        form_values[f.var] = f.validate(ans)
            result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return

            body = result.description + "\n"
            for item in result.items:
                for f in result.fields:
                    if f.type == "jid-single":
                        j = JID(item[f.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[f.var]  # type:ignore
                    body += f"\n{f.label or f.var}: {value}"
            msg.reply(body).send()

    @staticmethod
    async def __wrap_handler(msg, f: Union[Callable, functools.partial], *a, **k):
        """
        Call a command handler, awaiting it when it is a coroutine function
        (directly or wrapped in a :func:`functools.partial`).

        Any exception is logged at DEBUG level and reported to the user as
        ``"Error: ..."``; ``None`` is returned in that case.
        """
        try:
            if asyncio.iscoroutinefunction(f):
                return await f(*a, **k)
            elif hasattr(f, "func") and asyncio.iscoroutinefunction(f.func):
                return await f(*a, **k)
            else:
                return f(*a, **k)
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest) -> None:
        """
        Reply to ``help`` (list available commands) or ``help <command>``
        (describe one command).

        The command name is matched case-insensitively, like command dispatch
        in :meth:`_handle_message`.

        :raises XMPPError: ``item-not-found`` (via :meth:`_not_found`) when the
            command is unknown or more than one argument is given
        """
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            # Report what the user typed, not the repr of the argument tuple.
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID):
        """
        Build the list of commands ``mfrom`` is authorized to use, sorted by
        category name (commands without a recognized category sort first) and
        then by trigger word.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str):
        """
        Tell the user ``word`` is not a known command, then raise.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)
def percent_encode(jid: "JID") -> str:
    """
    Render ``jid`` as ``localpart@domain`` with the local part percent-encoded,
    for embedding in an ``xmpp:`` URI. The resource part is not included.

    A JID without a local part (e.g. a bare domain) is rendered as the domain
    alone, instead of the malformed ``@domain``.
    """
    if not jid.user:
        return jid.server  # type:ignore
    return f"{url_quote(jid.user)}@{jid.server}"  # type:ignore
# Module logger; ChatCommandProvider.__wrap_handler logs handler exceptions to it at DEBUG level.
log = logging.getLogger(__name__)