Coverage for slidge / command / chat_command.py: 71%
231 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 05:07 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
# Ad-hoc commands should provide a better UX, but some clients do not support them,
# so this is mostly a fallback.
5import asyncio
6import inspect
7import logging
8from collections.abc import Awaitable, Callable
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Literal,
13 Never,
14 ParamSpec,
15 TypeVar,
16 cast,
17 overload,
18)
19from urllib.parse import quote as url_quote
21from slixmpp import JID, CoroutineCallback, Message, StanzaPath
22from slixmpp.exceptions import XMPPError
23from slixmpp.types import JidStr, MessageTypes
25from slidge.command.base import (
26 CommandResponseRecipientType,
27 CommandResponseSessionType,
28 ConfirmationRecipient,
29 ConfirmationSession,
30 FormRecipient,
31 FormSession,
32)
33from slidge.contact import LegacyContact
34from slidge.group import LegacyMUC
35from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession
37from . import Command, CommandResponseType, Confirmation, Form, TableResult
38from .categories import CommandCategory
40if TYPE_CHECKING:
41 from ..core.gateway import BaseGateway
43T = TypeVar("T")
44P = ParamSpec("P")
class ChatCommandProvider:
    """
    Fallback command interface driven by plain chat messages.

    Messages addressed to the gateway's bare JID are parsed as
    ``<command> [args...]`` and dispatched to the registered :class:`Command`
    whose ``CHAT_COMMAND`` equals the lowercased first word. The words
    ``help``, ``contact`` and ``room`` are handled specially. Multi-step
    results (forms, confirmations) are walked through by prompting the user
    with follow-up chat messages.
    """

    # Reply sent (and error text raised) when the first word matches no command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # NOTE(review): not read anywhere in this class; kept in case other code uses it.
        self._keywords = list[str]()
        # CHAT_COMMAND keyword -> command instance
        self._commands: dict[str, Command[AnySession]] = {}
        # bare JID -> future resolved with that user's next message body
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command[AnySession]) -> None:
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` keyword
        """
        keyword = command.CHAT_COMMAND
        if keyword in self._commands:
            raise RuntimeError(f"There is already a command triggered by '{keyword}'")
        self._commands[keyword] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component
        will not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but
        rather intercepted. Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an
            :class:`asyncio.Future` is returned instead of a str
        :param msg_kwargs: Extra keyword arguments passed to ``send_message``
            when sending the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if the user does not
            reply within *timeout* seconds
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # A later input() call for the same JID may have replaced our
            # future: only unregister it if it is still ours, otherwise the
            # newer request would lose its reply (or this would KeyError).
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message) -> None:
        """
        Handle a chat message addressed to the gateway's bare JID.

        If the sender has a pending :meth:`input` request, the body resolves
        it. Otherwise the first word selects the action: ``help``,
        ``contact``/``room`` (recipient commands) or a registered chat command,
        which is run with the remaining words as arguments.

        :raises XMPPError: if the command is unknown or the sender is not
            authorized to run it (the sender is also told so in a reply)
        """
        if not msg["body"]:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        f = self._input_futures.pop(mfrom.bare, None)
        # The future may already be done (e.g. a non-blocking caller cancelled
        # it); set_result would then raise InvalidStateError, so fall through
        # and treat the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        first_word, *rest = msg["body"].split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        if first_word in ("contact", "room"):
            return await self._handle_recipient(first_word, msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            msg, command.run, session, mfrom, *rest
        )
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when followed, sends *body* to the gateway.

        *body* is percent-encoded so that spaces and URI delimiters such as
        ``;``, ``=``, ``&`` or ``#`` inside the value do not break the URI.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(
        self,
        result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any],
        msg: Message,
        session: "AnySession | None",
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]:
        """
        Render a command result to the user as chat messages.

        Strings (and ``None``) are replied as-is; forms and confirmations are
        walked through interactively and the follow-up result is handled
        recursively; table results are rendered as text.

        :param result: what the command (or a previous step) returned
        :param msg: the message that triggered the command, used for replies
        :param session: the sender's session, if any
        :param recipient: the contact/MUC targeted by a recipient command, or
            ``None`` for a gateway command
        :raises RuntimeError: if *result* is of an unsupported type
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return None

        if isinstance(result, Form):
            if recipient is None:
                result = cast(FormSession[AnySession], result)
            else:
                result = cast(FormRecipient[AnyRecipient], result)
            try:
                return await self.__handle_form(  # type:ignore[return-value]
                    result, msg, session, recipient=recipient
                )
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return None
            if recipient is None:
                result = cast(ConfirmationSession[AnySession], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    session,
                    msg.get_from(),
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            else:
                result = cast(ConfirmationRecipient[AnyRecipient], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    recipient,
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            return await self._handle_result(result, msg, session, recipient=recipient)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return None

            entries: list[str] = []
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        value = f"xmpp:{percent_encode(JID(item[field.var]))}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    entries.append(f"\n{field.label or field.var}: {value}")
            msg.reply(result.description + "\n" + "".join(entries)).send()
            # The table has been sent; falling through would raise below.
            return None

        raise RuntimeError(f"Unsupported command result: {result!r}")

    async def __handle_form(
        self,
        result: Form,
        msg: Message,
        session: AnySession | None,
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseType:
        """
        Fill *result* field by field by prompting the user in chat, then call
        the form's handler with the collected values and handle its result.

        Replying ``abort`` to any prompt cancels the command.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
            else:
                if f.type == "list-multi":
                    msg.reply(
                        "Multiple selection allowed, use new lines as a separator, ie, "
                        "one selected item per line. To select no item, reply with a space "
                        "(the punctuation)."
                    ).send()
                if f.options:
                    for o in f.options:
                        msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
                if f.value:
                    msg.reply(f"Default: {f.value}").send()
                if f.type == "boolean":
                    msg.reply("yes: " + self.__make_uri("yes")).send()
                    msg.reply("no: " + self.__make_uri("no")).send()

                ans = await self.xmpp.input(
                    msg.get_from(),
                    (f.label or f.var) + "? (or 'abort')",
                    mtype="chat",
                )
                if ans.lower() == "abort":
                    return await self._handle_result("Command aborted", msg, session)
                if f.type == "boolean":
                    # Anything other than "yes" counts as a negative answer.
                    ans = "true" if ans.lower() == "yes" else "false"

                if f.type.endswith("multi"):
                    # A lone space means "no item selected" (see prompt above).
                    choices = [] if ans == " " else ans.split("\n")
                    form_values[f.var] = f.validate(choices)
                else:
                    form_values[f.var] = f.validate(ans)
        if recipient is None:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseSessionType[Any], new_result)
        else:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseRecipientType[Any], new_result)

        return await self._handle_result(new_result, msg, session, recipient=recipient)

    @staticmethod
    async def __wrap_handler(
        msg: Message,
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T | None:
        """
        Call a command handler, awaiting its result if it is awaitable.

        Works for coroutine functions, ``functools.partial`` objects wrapping
        them, and plain callables. Any exception is logged, reported to the
        user as ``Error: ...`` and turned into a ``None`` result.
        """
        try:
            result = f(*a, **k)
            if inspect.isawaitable(result):
                return await result  # type:ignore[no-any-return]
            return result  # type:ignore[return-value]
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """
        Reply to ``help`` (list the commands available to the sender) or
        ``help <command>`` (describe one command).

        :raises XMPPError: ``item-not-found`` if the requested command does
            not exist
        """
        if not rest:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Lowercase like _handle_message does, so help matches dispatch.
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID) -> str:
        """
        List the commands *mfrom* is authorized to run, sorted by category
        name then keyword.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str) -> Never:
        """Tell the user *word* is not a known command, then raise ``item-not-found``."""
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

    async def _handle_recipient(
        self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
    ) -> None:
        """
        Handle ``contact ...`` / ``room ...`` messages, which run a command
        against one of the sender's contacts or group chats.

        Expected form: ``<contact|room> <jid_username> <command_name> [args...]``.
        With no argument, or ``help``, the available commands are listed.

        :raises XMPPError: ``subscription-required`` if the sender has no
            session, ``bad-request`` if arguments are missing, ``item-not-found``
            if the command does not exist
        """
        session = self.xmpp.get_session_from_jid(msg.get_from())

        recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

        if session is None:
            raise XMPPError("subscription-required")

        if len(args) == 0 or args[0] == "help":
            self.xmpp.delivery_receipt.ack(msg)
            self._help_recipient(msg, recipient_cls)
            return

        if len(args) == 1:
            self._help_recipient(msg, recipient_cls)
            raise XMPPError(
                "bad-request",
                f"{recipient_str.capitalize()} commands require at least two parameters: "
                f"{recipient_str}_jid_username and command_name",
            )

        jid_username, command_name, *rest = args

        command = recipient_cls.commands_chat.get(command_name)
        if command is None:
            raise XMPPError("item-not-found")

        if recipient_cls is LegacyContact:
            legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
            recipient = await session.contacts.by_legacy_id(legacy_id)
        else:
            legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
            recipient = await session.bookmarks.by_legacy_id(legacy_id)

        result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session, recipient)

    def _help_recipient(
        self, msg: Message, recipient_cls: type[AnyContact | AnyMUC]
    ) -> None:
        """Reply with the chat commands available on *recipient_cls*."""
        msg.reply(
            "Available commands:\n"
            + "\n".join(
                f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}"
                for co in recipient_cls.commands_chat.values()
            )
        ).send()
def percent_encode(jid: JID) -> str:
    """
    Render the bare part of *jid* for use in an ``xmpp:`` URI.

    The localpart is percent-encoded with :func:`urllib.parse.quote`; the
    domain is used as-is and any resource is dropped. A JID without a
    localpart (e.g. a bare domain) yields just the domain rather than an
    invalid ``@domain``.
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"
# Module logger; records exceptions raised by chat command handlers.
log: logging.Logger = logging.getLogger(__name__)