Coverage for slidge / command / chat_command.py: 71%
232 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-27 20:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-27 20:49 +0000
1# Handle slidge commands by exchanging chat messages with the gateway components.
3# Ad-hoc methods should provide a better UX, but some clients do not support them,
4# so this is mostly a fallback.
5import asyncio
6import inspect
7import logging
8from collections.abc import Awaitable, Callable
9from typing import (
10 TYPE_CHECKING,
11 Any,
12 Literal,
13 Never,
14 ParamSpec,
15 TypeVar,
16 cast,
17 overload,
18)
19from urllib.parse import quote as url_quote
21from slixmpp import JID, CoroutineCallback, Message, StanzaPath
22from slixmpp.exceptions import XMPPError
23from slixmpp.types import JidStr, MessageTypes
25from slidge.command.base import (
26 CommandResponseRecipientType,
27 CommandResponseSessionType,
28 ConfirmationRecipient,
29 ConfirmationSession,
30 FormRecipient,
31 FormSession,
32)
33from slidge.contact import LegacyContact
34from slidge.group import LegacyMUC
35from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession
37from . import Command, CommandResponseType, Confirmation, Form, TableResult
38from .categories import CommandCategory
40if TYPE_CHECKING:
41 from ..core.gateway import BaseGateway
43T = TypeVar("T")
44P = ParamSpec("P")
class ChatCommandProvider:
    """
    Expose slidge commands through plain chat messages sent to the gateway.

    This is a fallback for clients without ad-hoc command support. The first
    word of a message addressed to the gateway's bare JID selects a command,
    and the remaining space-separated words are passed to it as arguments.
    Follow-up interaction (forms, confirmations) happens by exchanging further
    chat messages.
    """

    # Reply template used when a message's first word matches no command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # Chat trigger word (Command.CHAT_COMMAND) -> command instance.
        self._commands: dict[str, Command[AnySession]] = {}
        # Pending input() requests, keyed by the bare JID we are waiting on.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command[AnySession]) -> None:
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        trigger = command.CHAT_COMMAND
        if trigger in self._commands:
            # Exception constructors do not apply logging-style "%s" args,
            # so format the message explicitly.
            raise RuntimeError(f"There is already a command triggered by '{trigger}'")
        self._commands[trigger] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sent to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Number of seconds to wait for a reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            when sending the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        # We are inside a coroutine, so a running loop is guaranteed.
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # Only drop our own future: a later input() call for the same JID
            # may have replaced it (and that one may already have been popped).
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message) -> None:
        """
        Dispatch a chat message addressed to the gateway's bare JID.

        A message answering a pending :meth:`input` call resolves that call.
        Otherwise the first word selects ``help``, a ``contact``/``room``
        command, or a registered gateway command.
        """
        if not msg["body"]:
            return

        if not msg.get_from().node:
            return  # ignore component and server messages

        f = self._input_futures.pop(msg.get_from().bare, None)
        # A future may already be done (e.g. cancelled by the caller of a
        # non-blocking input()); set_result would raise InvalidStateError,
        # so handle the message as a regular one in that case.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        first_word, *rest = msg["body"].split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        if first_word in ("contact", "room"):
            return await self._handle_recipient(first_word, msg, *rest)

        mfrom = msg.get_from()

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            # Tell the user in-chat before propagating the stanza error.
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            msg, command.run, session, mfrom, *rest
        )
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that pre-fills a message to the gateway.

        The body is percent-encoded (RFC 5122) so values containing spaces or
        URI delimiters such as ``;``, ``&``, ``#`` or ``?`` still yield a
        well-formed link.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body, safe='')}"

    async def _handle_result(
        self,
        result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any],
        msg: Message,
        session: "AnySession | None",
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]:
        """
        Render a command response as chat messages, driving multi-step flows.

        * ``str``/``None``: sent as the final reply ("End of command." for None).
        * :class:`Form`: fields are asked one by one, then the form handler's
          result is handled recursively. On an input timeout,
          ``result.timeout_handler`` is called if set.
        * :class:`Confirmation`: the user is prompted. An answer not starting
          with "y" cancels; otherwise the handler's result is handled
          recursively.
        * :class:`TableResult`: flattened into a single text message.

        :param result: what the command (or a previous step) returned
        :param msg: the triggering message, used to send replies
        :param session: the user's session, if any
        :param recipient: the contact/room for recipient commands, None for
            gateway commands
        :raises RuntimeError: if *result* has an unsupported type
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return None

        if isinstance(result, Form):
            if recipient is None:
                result = cast(FormSession[AnySession], result)
            else:
                result = cast(FormRecipient[AnyRecipient], result)
            try:
                return await self.__handle_form(  # type:ignore[return-value]
                    result, msg, session, recipient=recipient
                )
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise e
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return None
            if recipient is None:
                result = cast(ConfirmationSession[AnySession], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    session,
                    msg.get_from(),
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            else:
                result = cast(ConfirmationRecipient[AnyRecipient], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    recipient,
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            return await self._handle_result(result, msg, session, recipient=recipient)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return None

            parts = [result.description + "\n"]
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        # Render JIDs as clickable xmpp: URIs; MUCs get ?join.
                        j = JID(item[field.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    parts.append(f"\n{field.label or field.var}: {value}")
            msg.reply("".join(parts)).send()
            return None

        raise RuntimeError

    async def __handle_form(
        self,
        result: Form,
        msg: Message,
        session: AnySession | None,
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseType:
        """
        Fill a :class:`Form` interactively, one field per chat question.

        Fixed fields are only displayed. Every other field is asked through
        ``self.xmpp.input``, after its options, default value, or yes/no links
        are shown. Replying "abort" ends the command. Collected values are
        passed through each field's ``validate`` and then to the form handler:
        ``(form_values, session, jid, *args)`` for gateway commands or
        ``(recipient, form_values, *args)`` for recipient commands. The
        handler's result is handled recursively.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
                continue

            if f.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if f.options:
                for o in f.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if f.value:
                msg.reply(f"Default: {f.value}").send()
            if f.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(),
                (f.label or f.var) + "? (or 'abort')",
                mtype="chat",
            )
            # Clients may add stray whitespace/newlines around short answers.
            normalized = ans.strip().lower()
            if normalized == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if f.type == "boolean":
                # Accept the same kind of affirmative answers as confirmations.
                ans = "true" if normalized in ("yes", "y", "true") else "false"

            if f.type.endswith("multi"):
                # A single space is the documented way to select nothing.
                choices = [] if ans == " " else ans.split("\n")
                form_values[f.var] = f.validate(choices)
            else:
                form_values[f.var] = f.validate(ans)
        if recipient is None:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseSessionType[Any], new_result)
        else:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseRecipientType[Any], new_result)

        return await self._handle_result(new_result, msg, session, recipient=recipient)

    @staticmethod
    async def __wrap_handler(
        msg: Message,
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T | None:
        """
        Call a command handler, sync or async, reporting failures in chat.

        The handler is called and its return value is awaited if awaitable.
        This covers coroutine functions, ``functools.partial`` wrappers, and
        callables with an async ``__call__``. Any exception is logged, sent
        to the user as ``Error: <exc>``, and turned into a ``None`` result.
        """
        try:
            outcome = f(*a, **k)
            if inspect.isawaitable(outcome):
                outcome = await outcome
            return outcome  # type:ignore[return-value]
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """
        Reply to ``help`` (list available commands) or ``help <command>``.

        :raises XMPPError: ``item-not-found`` (via :meth:`_not_found`) when the
            requested command does not exist or too many words were given
        """
        # Drop empty words produced by repeated or trailing spaces.
        words = [w for w in rest if w]
        if not words:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
            return

        # Triggers are matched lower-cased in _handle_message; do the same here.
        command = self._commands.get(words[0].lower()) if len(words) == 1 else None
        if command is None:
            self._not_found(msg, " ".join(words))
            return
        reply = msg.reply()
        reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
        reply.send()

    def _help(self, mfrom: JID) -> str:
        """
        List the commands *mfrom* is authorized to run.

        Commands are sorted by category name, then by chat trigger word.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        def sort_key(co: Command[AnySession]) -> tuple[str, str]:
            if isinstance(co.CATEGORY, str):
                category = co.CATEGORY
            elif isinstance(co.CATEGORY, CommandCategory):
                category = co.CATEGORY.name
            else:
                category = ""
            return category, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue  # hide commands the user may not run
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str) -> Never:
        """
        Tell the user *word* is not a known command, then fail the stanza.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

    async def _handle_recipient(
        self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
    ) -> None:
        """
        Run a contact/room command: ``<contact|room> <jid_username> <command> [args...]``.

        With no arguments, or with ``help`` as the first one, the available
        commands for that recipient type are listed instead.

        :raises XMPPError: ``subscription-required`` if the sender has no
            session, ``bad-request`` if the command name is missing,
            ``item-not-found`` if the command does not exist
        """
        session = self.xmpp.get_session_from_jid(msg.get_from())

        recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

        if session is None:
            raise XMPPError("subscription-required")

        if len(args) == 0 or args[0] == "help":
            self.xmpp.delivery_receipt.ack(msg)
            self._help_recipient(msg, recipient_cls)
            return

        if len(args) == 1:
            self._help_recipient(msg, recipient_cls)
            raise XMPPError(
                "bad-request",
                f"{recipient_str.capitalize()} commands require at least two parameters: "
                f"{recipient_str}_jid_username and command_name",
            )

        jid_username, command_name, *rest = args

        command = recipient_cls.commands_chat.get(command_name)
        if command is None:
            # Same in-chat feedback as for unknown gateway commands.
            self._not_found(msg, command_name)
            return

        if recipient_cls is LegacyContact:
            legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
            recipient = await session.contacts.by_legacy_id(legacy_id)
        else:
            legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
            recipient = await session.bookmarks.by_legacy_id(legacy_id)

        result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session, recipient)

    def _help_recipient(
        self, msg: Message, recipient_cls: type[AnyContact | AnyMUC]
    ) -> None:
        """Reply with the chat commands available on *recipient_cls*."""
        msg.reply(
            "Available commands:\n"
            + "\n".join(
                f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}"
                for co in recipient_cls.commands_chat.values()
            )
        ).send()
def percent_encode(jid: "JID") -> str:
    """
    Return the bare form of *jid* for use in an ``xmpp:`` URI.

    The localpart is percent-encoded (UTF-8, via ``urllib.parse.quote``). The
    domain is kept as-is and any resource is dropped. A JID without a
    localpart, such as a server or component JID, yields just the domain
    rather than a malformed ``@domain``.
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"
# Module-level logger; looked up by name at call time (e.g. in handler error
# reporting), so defining it after the class is fine.
log = logging.getLogger(name=__name__)