Coverage for slidge / command / chat_command.py: 71%

231 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1# Handle slidge commands by exchanging chat messages with the gateway components. 

2 

3# Ad-hoc methods should provide a better UX, but some clients do not support them, 

4# so this is mostly a fallback. 

5import asyncio 

6import inspect 

7import logging 

8from collections.abc import Awaitable, Callable 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 Literal, 

13 Never, 

14 ParamSpec, 

15 TypeVar, 

16 cast, 

17 overload, 

18) 

19from urllib.parse import quote as url_quote 

20 

21from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

22from slixmpp.exceptions import XMPPError 

23from slixmpp.types import JidStr, MessageTypes 

24 

25from slidge.command.base import ( 

26 CommandResponseRecipientType, 

27 CommandResponseSessionType, 

28 ConfirmationRecipient, 

29 ConfirmationSession, 

30 FormRecipient, 

31 FormSession, 

32) 

33from slidge.contact import LegacyContact 

34from slidge.group import LegacyMUC 

35from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession 

36 

37from . import Command, CommandResponseType, Confirmation, Form, TableResult 

38from .categories import CommandCategory 

39 

40if TYPE_CHECKING: 

41 from ..core.gateway import BaseGateway 

42 

43T = TypeVar("T") 

44P = ParamSpec("P") 

45 

46 

class ChatCommandProvider:
    """
    Run slidge commands through plain chat messages sent to the gateway JID.

    Messages addressed to the component's bare JID are parsed as
    ``<word> [args...]`` (space-separated). ``help``, ``contact`` and ``room``
    are handled here; any other first word is looked up among the commands
    registered with :meth:`register`. The class also provides a simple
    "prompt, then capture the user's next message" mechanism via :meth:`input`.
    """

    # Reply template for an unknown command; ``{}`` receives the word the user typed.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # Chat trigger word (``Command.CHAT_COMMAND``) -> command.
        self._commands: dict[str, Command[AnySession]] = {}
        # Bare JID -> future resolved with the next message body from that JID.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command[AnySession]) -> None:
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly; instead use :meth:`.BaseSession.input`
        to directly target a user.

        NB: When using this, the next message that the user sent to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            return await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # Remove our pending future whatever happened (reply, timeout or
            # cancellation) — but only if a newer input() call for the same JID
            # has not already replaced it in the dict.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message) -> None:
        """
        Dispatch a chat message sent to the gateway's bare JID.

        A pending :meth:`input` future for the sender takes priority. Otherwise
        the first word (case-insensitive) selects ``help``, ``contact``/``room``
        or a registered command; the other space-separated words are arguments.

        :raises XMPPError: when the command is unknown or the sender is not
            authorized to use it
        """
        if not msg["body"]:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        f = self._input_futures.pop(mfrom.bare, None)
        # A future whose consumer was cancelled is already done: setting its
        # result would raise InvalidStateError and swallow the message, so in
        # that case process the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        if first_word in ("contact", "room"):
            return await self._handle_recipient(first_word, msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            msg, command.run, session, mfrom, *rest
        )
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when clicked, prepares a message with ``body``
        addressed to the gateway.

        ``body`` is percent-encoded so that spaces, ``;``, ``&``, ``#`` or
        non-ASCII characters in option values do not break the URI.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body, safe='')}"

    async def _handle_result(
        self,
        result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any],
        msg: Message,
        session: "AnySession | None",
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]:
        """
        Render a command response as chat messages, following up interactively.

        - ``str`` / ``None``: sent as a reply (``None`` becomes "End of command.").
        - :class:`Form`: each field is asked for in turn, then the form handler runs.
        - :class:`Confirmation`: the user is asked to confirm; a reply not starting
          with "y" cancels.
        - :class:`TableResult`: rendered as a single text message.

        Responses returned by follow-up handlers are handled recursively.

        :param result: what the command (or a previous step) returned
        :param msg: the user's message, used to address replies
        :param session: the sender's session, if any
        :param recipient: the contact/room targeted by a ``contact``/``room``
            command, or ``None`` for a gateway-wide command
        :raises RuntimeError: if ``result`` is of an unsupported type
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return None

        if isinstance(result, Form):
            if recipient is None:
                result = cast(FormSession[AnySession], result)
            else:
                result = cast(FormRecipient[AnyRecipient], result)
            try:
                return await self.__handle_form(  # type:ignore[return-value]
                    result, msg, session, recipient=recipient
                )
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise
                # NOTE(review): the timeout handler's return value is handed back
                # to the caller rather than rendered through _handle_result;
                # confirm that is the intended contract.
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return None
            if recipient is None:
                result = cast(ConfirmationSession[AnySession], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    session,
                    msg.get_from(),
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            else:
                result = cast(ConfirmationRecipient[AnyRecipient], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    recipient,
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            return await self._handle_result(result, msg, session, recipient=recipient)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return None

            body = result.description + "\n"
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        j = JID(item[field.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    body += f"\n{field.label or field.var}: {value}"
            msg.reply(body).send()
            # Without this return, every non-empty table fell through to the
            # RuntimeError below after having been sent.
            return None

        raise RuntimeError(f"Unsupported command result type: {type(result)!r}")

    async def __handle_form(
        self,
        result: Form,
        msg: Message,
        session: AnySession | None,
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseType:
        """
        Fill a :class:`Form` by asking the user for each field over chat.

        ``fixed`` fields are only displayed. For the others, options, default
        value and (for booleans) yes/no shortcuts are sent as clickable URIs,
        then the answer is read with ``self.xmpp.input``. Replying "abort" at
        any prompt cancels the whole form. Once all fields are collected, the
        form handler runs and its response is handled by :meth:`_handle_result`.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
                continue

            if f.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if f.options:
                for o in f.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if f.value:
                msg.reply(f"Default: {f.value}").send()
            if f.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(),
                (f.label or f.var) + "? (or 'abort')",
                mtype="chat",
            )
            if ans.lower() == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if f.type == "boolean":
                # Accept the usual affirmative spellings, not only "yes";
                # previously "true"/"y"/"1" silently became false.
                ans = "true" if ans.strip().lower() in ("yes", "y", "true", "1") else "false"

            if f.type.endswith("multi"):
                choices = [] if ans == " " else ans.split("\n")
                form_values[f.var] = f.validate(choices)
            else:
                form_values[f.var] = f.validate(ans)

        if recipient is None:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseSessionType[Any], new_result)
        else:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseRecipientType[Any], new_result)

        return await self._handle_result(new_result, msg, session, recipient=recipient)

    @staticmethod
    async def __wrap_handler(
        msg: Message,
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T | None:
        """
        Call a sync or async command handler, reporting failures to the user.

        :param msg: the user's message, used to reply with the error text
        :param f: the handler; its result is awaited if it is awaitable
        :return: the handler's (awaited) result, or ``None`` if it raised
        """
        try:
            res = f(*a, **k)
            # Covers coroutine functions, functools.partial wrappers around them
            # and any other callable returning an awaitable.
            if inspect.isawaitable(res):
                return await res  # type:ignore[no-any-return]
            return res  # type:ignore[return-value]
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """
        Reply to ``help`` (list of commands) or ``help <command>`` (details).

        :raises XMPPError: ``item-not-found`` if the requested command is unknown
        """
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            # Lowercased to match the case-insensitive dispatch in _handle_message.
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID) -> str:
        """
        Build the list of commands ``mfrom`` is authorized to use, sorted by
        category name then trigger word.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str) -> Never:
        """
        Tell the user ``word`` is not a known command, then abort.

        :raises XMPPError: always, ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

    async def _handle_recipient(
        self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
    ) -> None:
        """
        Handle ``contact <jid_username> <command> [args...]`` and the ``room``
        equivalent.

        :raises XMPPError: ``subscription-required`` without a session,
            ``bad-request`` with too few arguments, ``item-not-found`` for an
            unknown command
        """
        session = self.xmpp.get_session_from_jid(msg.get_from())
        if session is None:
            raise XMPPError("subscription-required")

        recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

        if len(args) == 0 or args[0] == "help":
            self.xmpp.delivery_receipt.ack(msg)
            self._help_recipient(msg, recipient_cls)
            return

        if len(args) == 1:
            self._help_recipient(msg, recipient_cls)
            raise XMPPError(
                "bad-request",
                f"{recipient_str.capitalize()} commands require at least two parameters: "
                f"{recipient_str}_jid_username and command_name",
            )

        jid_username, command_name, *rest = args

        command = recipient_cls.commands_chat.get(command_name)
        if command is None:
            # Reply to the user like unknown gateway commands do, instead of
            # only raising an error stanza.
            self._not_found(msg, command_name)

        if recipient_cls is LegacyContact:
            legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
            recipient = await session.contacts.by_legacy_id(legacy_id)
        else:
            legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
            recipient = await session.bookmarks.by_legacy_id(legacy_id)

        result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session, recipient)

    def _help_recipient(
        self, msg: Message, recipient_cls: type[AnyContact | AnyMUC]
    ) -> None:
        """Reply with the chat commands available on ``recipient_cls``."""
        msg.reply(
            "Available commands:\n"
            + "\n".join(
                f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}"
                for co in recipient_cls.commands_chat.values()
            )
        ).send()

461 

462 

def percent_encode(jid: "JID") -> str:
    """
    Render a JID as the path part of an ``xmpp:`` URI.

    The localpart is percent-encoded with :func:`urllib.parse.quote`. The
    domain is kept as-is, and any resource is dropped. A JID without a
    localpart (e.g. a bare domain) yields just the domain, instead of the
    malformed ``@domain``.

    :param jid: the JID to encode
    :return: ``localpart@domain`` with the localpart percent-encoded, or ``domain``
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"

465 

466 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)