Coverage for slidge / command / chat_command.py: 71%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-27 20:49 +0000

# Handle slidge commands by exchanging chat messages with the gateway component.
#
# Ad-hoc commands should provide a better UX, but some clients do not support them,
# so this is mostly a fallback.

5import asyncio 

6import inspect 

7import logging 

8from collections.abc import Awaitable, Callable 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 Literal, 

13 Never, 

14 ParamSpec, 

15 TypeVar, 

16 cast, 

17 overload, 

18) 

19from urllib.parse import quote as url_quote 

20 

21from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

22from slixmpp.exceptions import XMPPError 

23from slixmpp.types import JidStr, MessageTypes 

24 

25from slidge.command.base import ( 

26 CommandResponseRecipientType, 

27 CommandResponseSessionType, 

28 ConfirmationRecipient, 

29 ConfirmationSession, 

30 FormRecipient, 

31 FormSession, 

32) 

33from slidge.contact import LegacyContact 

34from slidge.group import LegacyMUC 

35from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession 

36 

37from . import Command, CommandResponseType, Confirmation, Form, TableResult 

38from .categories import CommandCategory 

39 

40if TYPE_CHECKING: 

41 from ..core.gateway import BaseGateway 

42 

43T = TypeVar("T") 

44P = ParamSpec("P") 

45 

46 

class ChatCommandProvider:
    """
    Handle slidge commands sent as plain chat messages to the gateway component.

    Every message addressed to the component's bare JID goes through
    :meth:`_handle_message`. That method does one of two things:

    - resolve a pending :meth:`input` request for the sender, or
    - interpret the first word of the body as ``help``, ``contact``, ``room``
      or the ``CHAT_COMMAND`` of a registered command.
    """

    # Reply (and error text) used when the first word matches no known command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # NOTE(review): never read within this class; verify before removing.
        self._keywords = list[str]()
        # CHAT_COMMAND trigger -> command instance, populated by register().
        self._commands: dict[str, Command[AnySession]] = {}
        # Sender bare JID -> future resolved with that sender's next message body.
        self._input_futures = dict[str, asyncio.Future[str]]()
        # Only messages addressed to the component's bare JID are intercepted.
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command[AnySession]) -> None:
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            # f-string: exception messages are not %-formatted like log calls.
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,  # noqa:ANN401
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly; instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component
        will not be transmitted to :meth:`.BaseGateway.on_gateway_message`. It is
        intercepted instead. Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for a reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an
            :class:`asyncio.Future` is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        # One pending input per bare JID: a newer request replaces an older one.
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # Only forget *our* future. A later input() call for the same JID
            # may have replaced it, and that newer request must stay pending.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message) -> None:
        """
        Route a message sent to the component's bare JID.

        Empty messages and messages from JIDs without a localpart are ignored.
        The body is used in one of two ways:

        - If an :meth:`input` request is pending for the sender, the body
          resolves it.
        - Otherwise, the first word (lowercased) selects ``help``,
          ``contact``/``room`` or a registered command. The remaining
          space-separated words are passed as arguments.

        :raises XMPPError: if the command is unknown or the sender is not authorized
        """
        if not msg["body"]:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        f = self._input_futures.pop(mfrom.bare, None)
        # A future already cancelled by its owner (e.g. a non-blocking input()
        # caller) cannot take a result. set_result() would raise
        # InvalidStateError, so handle the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        if first_word in ("contact", "room"):
            return await self._handle_recipient(first_word, msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            # Tell the user in chat too, then let slixmpp send the error stanza.
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            msg, command.run, session, mfrom, *rest
        )
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when clicked, sends *body* to the gateway.

        RFC 5122 requires query values to be percent-encoded. Without it, values
        containing spaces, ``;``, ``&`` or ``#`` would produce broken URIs.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body)}"

    async def _handle_result(
        self,
        result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any],
        msg: Message,
        session: "AnySession | None",
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]:
        """
        Render a command result as chat replies.

        Forms and confirmations are followed up until a text or table reply
        has been sent.

        :param result: What the command handler returned
        :param msg: The message that triggered the command; replies go to its sender
        :param session: The sender's session, if any
        :param recipient: The contact/room targeted, for contact/room commands
        :return: ``None`` once a final reply has been sent. A form whose input
            timed out returns whatever its ``timeout_handler`` returns.
        :raises RuntimeError: if ``result`` is of an unsupported type
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return None

        if isinstance(result, Form):
            if recipient is None:
                result = cast(FormSession[AnySession], result)
            else:
                result = cast(FormRecipient[AnyRecipient], result)
            try:
                return await self.__handle_form(  # type:ignore[return-value]
                    result, msg, session, recipient=recipient
                )
            except XMPPError as e:
                # Only input timeouts are delegated to the form's timeout handler.
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise e
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return None
            # Session handlers receive (session, jid, ...); recipient handlers
            # receive (recipient, ...).
            if recipient is None:
                result = cast(ConfirmationSession[AnySession], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    session,
                    msg.get_from(),
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            else:
                result = cast(ConfirmationRecipient[AnyRecipient], result)
                result = await self.__wrap_handler(
                    msg,
                    result.handler,
                    recipient,
                    *result.handler_args,
                    **result.handler_kwargs,
                )
            return await self._handle_result(result, msg, session, recipient=recipient)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return None

            body = result.description + "\n"
            for item in result.items:
                for f in result.fields:
                    if f.type == "jid-single":
                        # Render JIDs as clickable xmpp: URIs (joinable for MUCs).
                        j = JID(item[f.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[f.var]  # type:ignore
                    body += f"\n{f.label or f.var}: {value}"
            msg.reply(body).send()
            return None

        raise RuntimeError

    async def __handle_form(
        self,
        result: Form,
        msg: Message,
        session: AnySession | None,
        recipient: AnyRecipient | None = None,
    ) -> CommandResponseType:
        """
        Fill a form interactively.

        The sender is asked for each non-fixed field, one :meth:`.BaseGateway.input`
        prompt at a time. The collected values are then passed to the form handler.

        :return: The result of handling the form handler's response, or of the
            "Command aborted" reply if the user answered ``abort``
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
            else:
                if f.type == "list-multi":
                    msg.reply(
                        "Multiple selection allowed, use new lines as a separator, ie, "
                        "one selected item per line. To select no item, reply with a space "
                        "(the punctuation)."
                    ).send()
                if f.options:
                    for o in f.options:
                        msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
                if f.value:
                    msg.reply(f"Default: {f.value}").send()
                if f.type == "boolean":
                    msg.reply("yes: " + self.__make_uri("yes")).send()
                    msg.reply("no: " + self.__make_uri("no")).send()

                ans = await self.xmpp.input(
                    msg.get_from(),
                    (f.label or f.var) + "? (or 'abort')",
                    mtype="chat",
                )
                if ans.lower() == "abort":
                    return await self._handle_result("Command aborted", msg, session)
                if f.type == "boolean":
                    if ans.lower() == "yes":
                        ans = "true"
                    else:
                        ans = "false"

                if f.type.endswith("multi"):
                    # A lone space means "no selection" (see the hint above).
                    choices = [] if ans == " " else ans.split("\n")
                    form_values[f.var] = f.validate(choices)
                else:
                    form_values[f.var] = f.validate(ans)
        if recipient is None:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseSessionType[Any], new_result)
        else:
            new_result = await self.__wrap_handler(
                msg,
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
            new_result = cast(CommandResponseRecipientType[Any], new_result)

        return await self._handle_result(new_result, msg, session, recipient=recipient)

    @staticmethod
    async def __wrap_handler(
        msg: Message,
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T | None:
        """
        Call a command handler, awaiting its return value if it is awaitable.

        Any exception raised by the handler is logged at debug level. The
        user gets an ``Error: ...`` reply, and ``None`` is returned.
        """
        try:
            ret = f(*a, **k)
            # A single check covers coroutine functions, partials wrapping them,
            # and any other callable that returns an awaitable.
            if inspect.isawaitable(ret):
                return await ret  # type:ignore[no-any-return]
            return ret  # type:ignore[return-value]
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()
            return None

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """
        Reply to ``help`` (list commands) or ``help <command>`` (describe one).

        :raises XMPPError: ``item-not-found`` for an unknown command name
        """
        if not rest:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Lowercase like the main dispatch in _handle_message does.
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            # Show what the user typed, not a tuple repr like "('a', 'b')".
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID) -> str:
        """
        List the commands *mfrom* is authorized to use.

        Commands are sorted by category name, then by chat trigger.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        msg = "Available commands:"
        for c in sorted(
            self._commands.values(),
            key=lambda co: (
                (
                    co.CATEGORY
                    if isinstance(co.CATEGORY, str)
                    else (
                        co.CATEGORY.name
                        if isinstance(co.CATEGORY, CommandCategory)
                        else ""
                    )
                ),
                co.CHAT_COMMAND,
            ),
        ):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}"
        return msg

    def _not_found(self, msg: Message, word: str) -> Never:
        """
        Tell the sender *word* is not a known command, then raise.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

    async def _handle_recipient(
        self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
    ) -> None:
        """
        Run a contact or room command.

        Syntax: ``contact|room <jid_username> <command> [args...]``.

        :raises XMPPError: ``subscription-required`` without a session,
            ``bad-request`` when the command name is missing,
            ``item-not-found`` for an unknown command
        """
        session = self.xmpp.get_session_from_jid(msg.get_from())

        recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

        if session is None:
            raise XMPPError("subscription-required")

        if len(args) == 0 or args[0] == "help":
            self.xmpp.delivery_receipt.ack(msg)
            self._help_recipient(msg, recipient_cls)
            return

        if len(args) == 1:
            self._help_recipient(msg, recipient_cls)
            raise XMPPError(
                "bad-request",
                f"{recipient_str.capitalize()} commands require at least two "
                f"parameters: {recipient_str}_jid_username and command_name",
            )

        jid_username, command_name, *rest = args

        command = recipient_cls.commands_chat.get(command_name)
        if command is None:
            # Reply to the user like the top-level dispatch does, instead of
            # failing silently from their point of view.
            self._not_found(msg, command_name)

        if recipient_cls is LegacyContact:
            legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
            recipient = await session.contacts.by_legacy_id(legacy_id)
        else:
            legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
            recipient = await session.bookmarks.by_legacy_id(legacy_id)

        result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
        self.xmpp.delivery_receipt.ack(msg)
        await self._handle_result(result, msg, session, recipient)

    def _help_recipient(
        self, msg: Message, recipient_cls: type[AnyContact | AnyMUC]
    ) -> None:
        """Reply with the chat commands available for contacts or rooms."""
        msg.reply(
            "Available commands:\n"
            + "\n".join(
                f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}"
                for co in recipient_cls.commands_chat.values()
            )
        ).send()

463 

def percent_encode(jid: "JID") -> str:
    """
    Return the bare ``local@domain`` part of *jid* for use in an ``xmpp:`` URI.

    The localpart is percent-encoded; the resource, if any, is dropped.
    A JID without a localpart yields just the domain rather than the
    invalid ``@domain``.

    :param jid: The JID to encode; only its ``user`` and ``server`` parts are used
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"

466 

467 

# Module logger; __wrap_handler uses it to record handler exceptions.
log = logging.getLogger(name=__name__)