Coverage for slidge / command / chat_command.py: 72%
232 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-20 19:56 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import inspect
7import logging
8from collections.abc import Awaitable, Callable
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Generic,
13 Literal,
14 Never,
15 ParamSpec,
16 TypeVar,
17 cast,
18 overload,
19)
20from urllib.parse import quote as url_quote
22from slixmpp import JID, CoroutineCallback, Message, StanzaPath
23from slixmpp.exceptions import XMPPError
24from slixmpp.types import JidStr, MessageTypes
26from slidge.command.base import (
27 CommandResponseRecipientType,
28 CommandResponseSessionType,
29 ConfirmationRecipient,
30 ConfirmationSession,
31 FormRecipient,
32 FormSession,
33)
34from slidge.contact import LegacyContact
35from slidge.group import LegacyMUC
36from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession
38from . import Command, CommandResponseType, Confirmation, Form, TableResult
39from .categories import CommandCategory
41if TYPE_CHECKING:
42 from ..core.gateway import BaseGateway
44GatewayType = TypeVar("GatewayType", bound="BaseGateway[Any]")
45T = TypeVar("T")
46P = ParamSpec("P")
class ChatCommandProvider(Generic[GatewayType]):
    """
    Run slidge commands through plain chat messages exchanged with the gateway.

    Messages with a body, sent from a JID that has a localpart to the
    component's bare JID, are parsed as ``<keyword> [args...]``. ``help``,
    ``contact`` and ``room`` are handled specially; any other keyword is looked
    up among the registered :class:`Command` instances.

    This class also implements :meth:`input`, which intercepts the next message
    a user sends to the component instead of parsing it as a command.
    """

    UNKNOWN = "Wut? I don't know that command: {}"
    xmpp: GatewayType

    def __init__(self, xmpp: GatewayType) -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # Registered commands, keyed by their CHAT_COMMAND keyword.
        self._commands: dict[str, Command[AnySession]] = {}
        # Pending input() requests, keyed by the bare JID we await a reply from.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command[AnySession]) -> None:
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` keyword
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            # Exception args are not %-interpolated, so format the message here.
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[False],
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sent to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments passed to ``send_message``
            when sending the prompt
        :raises XMPPError: ``remote-server-timeout`` if the user does not reply
            within *timeout* seconds
        :return: The user's reply
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            return await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # On timeout or cancellation, drop our pending future so the next
            # message is not routed to it. Only remove it if it is still ours:
            # a later input() call for the same JID may have replaced it.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message) -> None:
        """
        Handle a message addressed to the component's bare JID.

        Messages without a body and messages from JIDs without a localpart are
        ignored. If an :meth:`input` call is pending for the sender, the body
        resolves it; otherwise the first word selects what to run.

        :raises XMPPError: if the command is unknown or the sender is not
            authorized to run it
        """
        if not msg["body"]:
            return

        if not msg.get_from().node:
            return  # ignore component and server messages

        f = self._input_futures.pop(msg.get_from().bare, None)
        # A future handed out by input(blocking=False) may have been cancelled
        # by its owner; set_result() would then raise InvalidStateError, so
        # treat the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        if first_word in ("contact", "room"):
            return await self._handle_recipient(first_word, msg, *rest)

        mfrom = msg.get_from()

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            msg, command.run, session, mfrom, *rest
        )
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` message URI addressed to the gateway, carrying *body*.

        The body is percent-encoded so that spaces and reserved characters
        (``;``, ``=``, ``&``, ``#``...) do not break the URI.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body)}"

    async def _handle_result(
        self,
        result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any],
        msg: Message,
        session: "AnySession | None",
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]:
        """
        Render a command result as chat replies.

        Forms and confirmations are followed up until a terminal result is
        reached.

        :param result: What the command (or a previous step) returned
        :param msg: The message that triggered the command; replies go to it
        :param session: The sender's session, if any
        :param recipient: The contact/room targeted by a ``contact``/``room``
            command, ``None`` for session commands
        :raises RuntimeError: if *result* is of an unsupported type
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return None

        if isinstance(result, Form):
            if recipient is None:
                result = cast(FormSession[AnySession], result)
            else:
                result = cast(FormRecipient[AnyRecipient], result)
            try:
                return await self.__handle_form(  # type:ignore[return-value]
                    result, msg, session, recipient=recipient
                )
            except XMPPError as e:
                # Only an input timeout is delegated to the form's timeout handler.
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise e
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return None
            if recipient is None:
                result = cast(ConfirmationSession[AnySession], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    session,
                    msg.get_from(),
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            else:
                result = cast(ConfirmationRecipient[AnyRecipient], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    recipient,
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            return await self._handle_result(result, msg, session, recipient=recipient)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return None

            lines: list[str] = []
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        j = JID(item[field.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    lines.append(f"\n{field.label or field.var}: {value}")
            msg.reply(result.description + "\n" + "".join(lines)).send()
            return None

        raise RuntimeError(f"Unsupported command result: {result!r}")

    async def __handle_form(
        self,
        result: Form,
        msg: Message,
        session: AnySession | None,
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseType:
        """
        Fill *result* field by field by prompting the user in chat.

        Each answer is validated by the field, then the form handler is called
        with the collected values and its result is rendered. Replying
        ``abort`` to any prompt ends the command.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
            else:
                if f.type == "list-multi":
                    msg.reply(
                        "Multiple selection allowed, use new lines as a separator, ie, "
                        "one selected item per line. To select no item, reply with a space "
                        "(the punctuation)."
                    ).send()
                if f.options:
                    for o in f.options:
                        msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
                if f.value:
                    msg.reply(f"Default: {f.value}").send()
                if f.type == "boolean":
                    msg.reply("yes: " + self.__make_uri("yes")).send()
                    msg.reply("no: " + self.__make_uri("no")).send()

                ans = await self.xmpp.input(
                    msg.get_from(),
                    (f.label or f.var) + "? (or 'abort')",
                    mtype="chat",
                )
                if ans.lower() == "abort":
                    return await self._handle_result("Command aborted", msg, session)
                if f.type == "boolean":
                    ans = "true" if ans.lower() == "yes" else "false"

                if f.type.endswith("multi"):
                    # A single space means "no item selected".
                    choices = [] if ans == " " else ans.split("\n")
                    form_values[f.var] = f.validate(choices)
                else:
                    form_values[f.var] = f.validate(ans)
        if recipient is None:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseSessionType[Any], new_result)
        else:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseRecipientType[Any], new_result)

        return await self._handle_result(new_result, msg, session, recipient=recipient)

    @staticmethod
    async def __wrap_handler(
        msg: Message,
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T | None:
        """
        Call a command handler, sync or async, and report failures in chat.

        Any awaitable returned by *f* is awaited. This covers coroutine
        functions, ``functools.partial`` objects wrapping them, and other
        callables that return a coroutine.

        :return: The handler's result, or ``None`` if it raised (the error
            text is then sent as a reply to *msg*)
        """
        try:
            ret = f(*a, **k)
            if inspect.isawaitable(ret):
                ret = await ret
            return ret  # type:ignore[return-value]
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """Reply with the command list, or with the help of a single command."""
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Lowercase like the main dispatch in _handle_message does.
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID) -> str:
        """
        List the commands *mfrom* is authorized to run.

        Commands are sorted by category name, then by keyword.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str) -> Never:
        """Tell the user *word* is not a known command, then raise ``item-not-found``."""
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

    async def _handle_recipient(
        self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
    ) -> None:
        """
        Run a contact or room command.

        Expected syntax: ``contact|room <jid_username> <command> [args...]``.

        :raises XMPPError: if the sender has no session, arguments are missing,
            or the command is unknown
        """
        session = self.xmpp.get_session_from_jid(msg.get_from())

        recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

        if session is None:
            raise XMPPError("subscription-required")

        if len(args) == 0 or args[0] == "help":
            self.xmpp.delivery_receipt.ack(msg)
            self._help_recipient(msg, recipient_cls)
            return

        if len(args) == 1:
            self._help_recipient(msg, recipient_cls)
            raise XMPPError(
                "bad-request",
                f"{recipient_str.capitalize()} commands require at least two parameters: {recipient_str}_jid_username and command_name",
            )

        jid_username, command_name, *rest = args

        command = recipient_cls.commands_chat.get(command_name)
        if command is None:
            # Same feedback as unknown session commands (chat reply + error).
            self._not_found(msg, command_name)

        if recipient_cls is LegacyContact:
            legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
            recipient = await session.contacts.by_legacy_id(legacy_id)
        else:
            legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
            recipient = await session.bookmarks.by_legacy_id(legacy_id)

        result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session, recipient)

    def _help_recipient(
        self, msg: Message, recipient_cls: type[AnyContact | AnyMUC]
    ) -> None:
        """Reply with the chat commands available for *recipient_cls*."""
        msg.reply(
            "Available commands:\n"
            + "\n".join(
                f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}"
                for co in recipient_cls.commands_chat.values()
            )
        ).send()
def percent_encode(jid: "JID") -> str:
    """
    Format *jid* as ``localpart@domain`` for use in an ``xmpp:`` URI.

    The localpart is percent-encoded. A JID without a localpart yields just
    its domain, instead of the malformed ``@domain``.

    :param jid: The JID to format (only ``user`` and ``server`` are read)
    :return: The URI-safe bare JID
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"
# Module logger; used by ChatCommandProvider to log command handler failures.
log = logging.getLogger(name=__name__)