Coverage for slidge/command/chat_command.py: 58%
172 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import functools
7import logging
8from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload
9from urllib.parse import quote as url_quote
11from slixmpp import JID, CoroutineCallback, Message, StanzaPath
12from slixmpp.exceptions import XMPPError
13from slixmpp.types import JidStr, MessageTypes
15from . import Command, CommandResponseType, Confirmation, Form, TableResult
16from .categories import CommandCategory
18if TYPE_CHECKING:
19 from ..core.gateway import BaseGateway
class ChatCommandProvider:
    """
    Handle slidge commands sent as plain chat messages to the gateway
    component's bare JID, and collect free-text input from users.
    """

    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway"):
        self.xmpp = xmpp
        self._keywords = list[str]()
        # CHAT_COMMAND trigger word -> command instance
        self._commands: dict[str, Command] = {}
        # bare JID -> future resolved with the body of that JID's next message
        self._input_futures = dict[str, asyncio.Future[str]]()
        # Only messages addressed to the component's bare JID are handled here.
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(
        self, jid: JidStr, text: Optional[str], blocking: Literal[False]
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: Optional[str],
        mtype: MessageTypes = ...,
        blocking: Literal[True] = ...,
    ) -> str: ...

    async def input(
        self,
        jid,
        text=None,
        mtype="chat",
        timeout=60,
        blocking=True,
        **msg_kwargs,
    ):
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sent to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user; no prompt is sent if None
        :param mtype: Message type
        :param timeout: Seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            when sending the prompt
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        :return: The user's reply
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except asyncio.TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # wait_for() has cancelled f. Only forget it if a newer input()
            # call for the same JID has not replaced it in the meantime.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message):
        """
        Handle a message addressed to the component's bare JID.

        The message either resolves a pending :meth:`input` request from the
        same bare JID, or is parsed as ``<command word> [args...]``.
        """
        if not msg["body"]:
            return

        if not msg.get_from().node:
            return  # ignore component and server messages

        f = self._input_futures.pop(msg.get_from().bare, None)
        # A non-blocking caller may have cancelled its future already; calling
        # set_result() on it would raise InvalidStateError, so fall through and
        # treat the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        mfrom = msg.get_from()

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        # ``body`` is a URI query value: percent-encode it so that spaces and
        # reserved characters (';', '&', '#', ...) survive (RFC 5122).
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command result as chat messages sent back to the user.

        Forms and confirmations are multi-step: they prompt the user, call the
        next handler, and recurse on its result.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            return await self._handle_form(result, msg, session)

        if isinstance(result, Confirmation):
            return await self._handle_confirmation(result, msg, session)

        if isinstance(result, TableResult):
            return self._handle_table(result, msg)

    async def _handle_form(self, result: Form, msg: Message, session):
        """Ask the user for every non-fixed form field, then submit the answers."""
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
                continue

            if f.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if f.options:
                for o in f.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if f.value:
                msg.reply(f"Default: {f.value}").send()
            if f.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(), (f.label or f.var) + "? (or 'abort')"
            )
            if ans.lower() == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if f.type == "boolean":
                ans = "true" if ans.lower() == "yes" else "false"

            if f.type.endswith("multi"):
                # A lone space is the documented way to select nothing.
                choices = [] if ans == " " else ans.split("\n")
                form_values[f.var] = f.validate(choices)
            else:
                form_values[f.var] = f.validate(ans)
        next_result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(next_result, msg, session)

    async def _handle_confirmation(self, result: Confirmation, msg: Message, session):
        """Ask a yes/no question and run the confirmation handler on a 'y...' reply."""
        yes_or_no = await self.input(msg.get_from(), result.prompt)
        if not yes_or_no.lower().startswith("y"):
            reply = msg.reply()
            reply["body"] = "Canceled"
            reply.send()
            return
        next_result = await self.__wrap_handler(
            msg,
            result.handler,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(next_result, msg, session)

    @staticmethod
    def _handle_table(result: TableResult, msg: Message):
        """Send a table result as a single text message, one 'label: value' per line."""
        if len(result.items) == 0:
            msg.reply("Empty results").send()
            return

        body = result.description + "\n"
        for item in result.items:
            for f in result.fields:
                if f.type == "jid-single":
                    j = JID(item[f.var])
                    value = f"xmpp:{percent_encode(j)}"
                    if result.jids_are_mucs:
                        value += "?join"
                else:
                    value = item[f.var]  # type:ignore
                body += f"\n{f.label or f.var}: {value}"
        msg.reply(body).send()

    @staticmethod
    async def __wrap_handler(msg, f: Union[Callable, functools.partial], *a, **k):
        """
        Call a command handler, sync or async, and return its result.

        Any exception raised by the handler is logged and reported to the user
        as an ``Error: ...`` reply, and ``None`` is returned.
        """
        try:
            result = f(*a, **k)
            # Coroutine functions, partials wrapping them, and plain callables
            # returning a coroutine all yield a coroutine here: await it.
            if asyncio.iscoroutine(result):
                result = await result
            return result
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest):
        """Reply with the command list, or with the help of one named command."""
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Command words are matched case-insensitively, as in _handle_message.
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID):
        """List the commands *mfrom* is authorized to run, sorted by category then name."""
        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str):
        """Tell the user *word* is not a known command, then raise item-not-found."""
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)
def percent_encode(jid: "JID") -> str:
    """
    Return the bare form of *jid* with its localpart percent-encoded, for use
    in an ``xmpp:`` URI.

    A JID without a localpart (e.g. a bare domain) is returned as the domain
    alone, instead of producing an invalid ``@domain`` string.
    """
    if not jid.user:
        return jid.server  # type:ignore
    return f"{url_quote(jid.user)}@{jid.server}"  # type:ignore
# Module-level logger; used to record errors raised by command handlers.
log: logging.Logger = logging.getLogger(name=__name__)