Coverage for slidge / command / chat_command.py: 72%

232 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-20 19:56 +0000

1# Handle slidge commands by exchanging chat messages with the gateway components. 

2 

3# Ad-hoc methods should provide a better UX, but some clients do not support them, 

4# so this is mostly a fallback. 

5import asyncio 

6import inspect 

7import logging 

8from collections.abc import Awaitable, Callable 

9from typing import ( 

10 TYPE_CHECKING, 

11 Any, 

12 Generic, 

13 Literal, 

14 Never, 

15 ParamSpec, 

16 TypeVar, 

17 cast, 

18 overload, 

19) 

20from urllib.parse import quote as url_quote 

21 

22from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

23from slixmpp.exceptions import XMPPError 

24from slixmpp.types import JidStr, MessageTypes 

25 

26from slidge.command.base import ( 

27 CommandResponseRecipientType, 

28 CommandResponseSessionType, 

29 ConfirmationRecipient, 

30 ConfirmationSession, 

31 FormRecipient, 

32 FormSession, 

33) 

34from slidge.contact import LegacyContact 

35from slidge.group import LegacyMUC 

36from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession 

37 

38from . import Command, CommandResponseType, Confirmation, Form, TableResult 

39from .categories import CommandCategory 

40 

41if TYPE_CHECKING: 

42 from ..core.gateway import BaseGateway 

43 

# Type of the gateway component a provider is attached to; bounded by
# BaseGateway, which is only imported for type checking (hence the string).
GatewayType = TypeVar("GatewayType", bound="BaseGateway[Any]")
# Return type of a handler given to ChatCommandProvider.__wrap_handler.
T = TypeVar("T")
# Parameters of a handler given to ChatCommandProvider.__wrap_handler.
P = ParamSpec("P")

47 

48 

49class ChatCommandProvider(Generic[GatewayType]): 

50 UNKNOWN = "Wut? I don't know that command: {}" 

51 xmpp: GatewayType 

52 

def __init__(self, xmpp: GatewayType) -> None:
    """
    Attach the provider to a gateway and start listening for chat messages.

    :param xmpp: the gateway component that receives the chat commands
    """
    self.xmpp = xmpp
    self._keywords: list[str] = []
    # First word of a chat message -> command it triggers.
    self._commands: dict[str, Command[AnySession]] = {}
    # Bare JID -> future resolved with the body of that JID's next message.
    self._input_futures: dict[str, asyncio.Future[str]] = {}

    # Route message stanzas addressed to the component's bare JID to
    # _handle_message.
    own_jid = self.xmpp.boundjid.bare
    callback = CoroutineCallback(
        "chat_command_handler",
        StanzaPath(f"message@to={own_jid}"),
        self._handle_message,  # type: ignore
    )
    self.xmpp.register_handler(callback)

65 

66 def register(self, command: Command[AnySession]) -> None: 

67 """ 

68 Register a command to be used via chat messages with the gateway 

69 

70 Plugins should not call this, any class subclassing Command should be 

71 automatically added by slidge core. 

72 

73 :param command: the new command 

74 """ 

75 t = command.CHAT_COMMAND 

76 if t in self._commands: 

77 raise RuntimeError("There is already a command triggered by '%s'", t) 

78 self._commands[t] = command 

79 

@overload
async def input(self, jid: JidStr, text: str | None = None) -> str: ...

@overload
async def input(
    self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
) -> asyncio.Future[str]: ...

@overload
async def input(
    self,
    jid: JidStr,
    text: str | None = None,
    *,
    mtype: MessageTypes = "chat",
    timeout: int = 60,
    blocking: Literal[True] = True,
    **msg_kwargs: Any,  # noqa:ANN401
) -> str: ...

async def input(
    self,
    jid: JidStr,
    text: str | None = None,
    *,
    mtype: MessageTypes = "chat",
    timeout: int = 60,
    blocking: bool = True,
    **msg_kwargs: Any,
) -> str | asyncio.Future[str]:
    """
    Request arbitrary user input using a simple chat message, and await the result.

    You shouldn't need to call this directly but instead use
    :meth:`.BaseSession.input` to directly target a user.

    NB: When using this, the next message that the user sends to the component will
    not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
    Await the coroutine to get its content.

    :param jid: The JID we want input from
    :param text: A prompt to display for the user; nothing is sent if ``None``
    :param mtype: Message type
    :param timeout: Number of seconds to wait for the reply when ``blocking``
        is true
    :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
        is returned instead of a str
    :param msg_kwargs: Extra keyword arguments passed on to ``send_message``
        for the prompt
    :return: The user's reply
    :raises XMPPError: ``remote-server-timeout`` if no reply came in time
    """
    jid = JID(jid)
    if text is not None:
        self.xmpp.send_message(
            mto=jid,
            mbody=text,
            mtype=mtype,
            mfrom=self.xmpp.boundjid.bare,
            **msg_kwargs,
        )
    # We are inside a coroutine, so there is always a running loop.
    pending: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    self._input_futures[jid.bare] = pending
    if not blocking:
        return pending
    try:
        return await asyncio.wait_for(pending, timeout)
    except TimeoutError:
        self.xmpp.send_message(
            mto=jid,
            mbody="You took too much time to reply",
            mtype=mtype,
            mfrom=self.xmpp.boundjid.bare,
        )
        raise XMPPError(
            "remote-server-timeout", "You took too much time to reply"
        ) from None
    finally:
        # Only forget *our* future: a later input() call for the same JID
        # may have replaced it, and that entry may even have been consumed
        # already, so an unconditional ``del`` could raise KeyError or
        # discard somebody else's pending request. Doing this in
        # ``finally`` also cleans up when the waiting task is cancelled.
        if self._input_futures.get(jid.bare) is pending:
            del self._input_futures[jid.bare]

154 

155 async def _handle_message(self, msg: Message) -> None: 

156 if not msg["body"]: 

157 return 

158 

159 if not msg.get_from().node: 

160 return # ignore component and server messages 

161 

162 f = self._input_futures.pop(msg.get_from().bare, None) 

163 if f is not None: 

164 f.set_result(msg["body"]) 

165 return 

166 

167 c = msg["body"] 

168 first_word, *rest = c.split(" ") 

169 first_word = first_word.lower() 

170 

171 if first_word == "help": 

172 return self._handle_help(msg, *rest) 

173 

174 if first_word in ("contact", "room"): 

175 return await self._handle_recipient(first_word, msg, *rest) 

176 

177 mfrom = msg.get_from() 

178 

179 command = self._commands.get(first_word) 

180 if command is None: 

181 self._not_found(msg, first_word) 

182 return 

183 

184 try: 

185 session = command.raise_if_not_authorized(mfrom) 

186 except XMPPError as e: 

187 reply = msg.reply() 

188 reply["body"] = e.text 

189 reply.send() 

190 raise 

191 

192 result: CommandResponseSessionType[Any] = await self.__wrap_handler( 

193 msg, command.run, session, mfrom, *rest 

194 ) 

195 self.xmpp.delivery_receipt.ack(msg) 

196 await self._handle_result(result, msg, session) 

197 

198 def __make_uri(self, body: str) -> str: 

199 return f"xmpp:{self.xmpp.boundjid.bare}?message;body={body}" 

200 

201 async def _handle_result( 

202 self, 

203 result: CommandResponseSessionType[Any] | CommandResponseRecipientType[Any], 

204 msg: Message, 

205 session: "AnySession | None", 

206 recipient: AnyRecipient | None = None, 

207 ) -> CommandResponseSessionType[Any] | CommandResponseRecipientType[Any]: 

208 if isinstance(result, str) or result is None: 

209 reply = msg.reply() 

210 reply["body"] = result or "End of command." 

211 reply.send() 

212 return None 

213 

214 if isinstance(result, Form): 

215 if recipient is None: 

216 result = cast(FormSession[AnySession], result) 

217 else: 

218 result = cast(FormRecipient[AnyRecipient], result) 

219 try: 

220 return await self.__handle_form( # type:ignore[return-value] 

221 result, msg, session, recipient=recipient 

222 ) 

223 except XMPPError as e: 

224 if ( 

225 result.timeout_handler is None 

226 or e.condition != "remote-server-timeout" 

227 ): 

228 raise e 

229 return result.timeout_handler() 

230 

231 if isinstance(result, Confirmation): 

232 yes_or_no = await self.input(msg.get_from(), result.prompt) 

233 if not yes_or_no.lower().startswith("y"): 

234 reply = msg.reply() 

235 reply["body"] = "Canceled" 

236 reply.send() 

237 return None 

238 if recipient is None: 

239 result = cast(ConfirmationSession[AnySession], result) 

240 result = await self.__wrap_handler( 

241 msg, 

242 result.handler, 

243 session, 

244 msg.get_from(), 

245 *result.handler_args, 

246 **result.handler_kwargs, 

247 ) 

248 else: 

249 result = cast(ConfirmationRecipient[AnyRecipient], result) 

250 result = await self.__wrap_handler( 

251 msg, 

252 result.handler, 

253 recipient, 

254 *result.handler_args, 

255 **result.handler_kwargs, 

256 ) 

257 return await self._handle_result(result, msg, session, recipient=recipient) 

258 

259 if isinstance(result, TableResult): 

260 if len(result.items) == 0: 

261 msg.reply("Empty results").send() 

262 return None 

263 

264 body = result.description + "\n" 

265 for item in result.items: 

266 for f in result.fields: 

267 if f.type == "jid-single": 

268 j = JID(item[f.var]) 

269 value = f"xmpp:{percent_encode(j)}" 

270 if result.jids_are_mucs: 

271 value += "?join" 

272 else: 

273 value = item[f.var] # type:ignore 

274 body += f"\n{f.label or f.var}: {value}" 

275 msg.reply(body).send() 

276 return None 

277 

278 raise RuntimeError 

279 

async def __handle_form(
    self,
    result: Form,
    msg: Message,
    session: AnySession | None,
    recipient: AnyRecipient | None = None,
) -> CommandResponseType:
    """
    Fill in a form by asking the user for one field after the other.

    The title and instructions are sent first. ``fixed`` fields are only
    displayed; for every other field the options and default value (if
    any) are shown and the user's next message is taken as the answer.
    Answering ``abort`` stops the command. Once all fields are collected,
    the form handler is called and its result is rendered.

    :param result: the form to fill in
    :param msg: the message that triggered the command, used for replies
    :param session: the session of the user running the command
    :param recipient: the contact or room the command targets, if any
    :return: whatever rendering the handler's result returns
    """
    answers = {}

    for header in (result.title, result.instructions):
        if header:
            msg.reply(header).send()

    for field in result.fields:
        if field.type == "fixed":
            # Read-only text: show it and move on, there is nothing to ask.
            msg.reply(f"{field.label or field.var}: {field.value}").send()
            continue

        if field.type == "list-multi":
            msg.reply(
                "Multiple selection allowed, use new lines as a separator, ie, "
                "one selected item per line. To select no item, reply with a space "
                "(the punctuation)."
            ).send()
        for option in field.options or ():
            # Each option comes with a link that sends its value back.
            uri = self.__make_uri(option["value"])
            msg.reply(f"{option['label']}: {uri}").send()
        if field.value:
            msg.reply(f"Default: {field.value}").send()
        if field.type == "boolean":
            msg.reply("yes: " + self.__make_uri("yes")).send()
            msg.reply("no: " + self.__make_uri("no")).send()

        prompt = (field.label or field.var) + "? (or 'abort')"
        answer = await self.xmpp.input(msg.get_from(), prompt, mtype="chat")
        if answer.lower() == "abort":
            return await self._handle_result("Command aborted", msg, session)
        if field.type == "boolean":
            # Anything but "yes" counts as a "no".
            answer = "true" if answer.lower() == "yes" else "false"

        if field.type.endswith("multi"):
            # A lone space is the agreed way to select nothing.
            selected = [] if answer == " " else answer.split("\n")
            answers[field.var] = field.validate(selected)
        else:
            answers[field.var] = field.validate(answer)

    if recipient is None:
        # Session command: handler(form_values, session, user_jid, ...)
        outcome = await self.__wrap_handler(
            msg,
            result.handler,
            answers,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        outcome = cast(CommandResponseSessionType[Any], outcome)
    else:
        # Recipient command: handler(recipient, form_values, ...)
        outcome = await self.__wrap_handler(
            msg,
            result.handler,
            recipient,
            answers,
            *result.handler_args,
            **result.handler_kwargs,
        )
        outcome = cast(CommandResponseRecipientType[Any], outcome)

    return await self._handle_result(outcome, msg, session, recipient=recipient)

348 

349 @staticmethod 

350 async def __wrap_handler( 

351 msg: Message, 

352 f: Callable[P, Awaitable[T] | T], 

353 *a: P.args, 

354 **k: P.kwargs, 

355 ) -> T | None: 

356 try: 

357 if inspect.iscoroutinefunction(f): 

358 return await f(*a, **k) # type:ignore[no-any-return] 

359 elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func): 

360 return await f(*a, **k) # type:ignore[misc,no-any-return] 

361 else: 

362 return f(*a, **k) # type:ignore[return-value] 

363 except Exception as e: 

364 log.debug("Error in %s", f, exc_info=e) 

365 reply = msg.reply() 

366 reply["body"] = f"Error: {e}" 

367 reply.send() 

368 return None 

369 

370 def _handle_help(self, msg: Message, *rest: str) -> None: 

371 if len(rest) == 0: 

372 reply = msg.reply() 

373 reply["body"] = self._help(msg.get_from()) 

374 reply.send() 

375 elif len(rest) == 1 and (command := self._commands.get(rest[0])): 

376 reply = msg.reply() 

377 reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}" 

378 reply.send() 

379 else: 

380 self._not_found(msg, str(rest)) 

381 

382 def _help(self, mfrom: JID) -> str: 

383 session = self.xmpp.get_session_from_jid(mfrom) 

384 

385 msg = "Available commands:" 

386 for c in sorted( 

387 self._commands.values(), 

388 key=lambda co: ( 

389 ( 

390 co.CATEGORY 

391 if isinstance(co.CATEGORY, str) 

392 else ( 

393 co.CATEGORY.name 

394 if isinstance(co.CATEGORY, CommandCategory) 

395 else "" 

396 ) 

397 ), 

398 co.CHAT_COMMAND, 

399 ), 

400 ): 

401 try: 

402 c.raise_if_not_authorized(mfrom, fetch_session=False, session=session) 

403 except XMPPError: 

404 continue 

405 msg += f"\n{c.CHAT_COMMAND} -- {c.NAME}" 

406 return msg 

407 

def _not_found(self, msg: Message, word: str) -> Never:
    """
    Tell the user that a command does not exist, then raise.

    :param msg: the message containing the unknown command, replied to
    :param word: what the user typed, inserted into ``UNKNOWN``
    :raises XMPPError: always, with the ``item-not-found`` condition
    """
    text = self.UNKNOWN.format(word)
    answer = msg.reply(text)
    answer.send()
    raise XMPPError("item-not-found", text)

412 

async def _handle_recipient(
    self, recipient_str: Literal["contact", "room"], msg: Message, *args: str
) -> None:
    """
    Run a command that targets a contact or a room.

    Expected syntax: ``contact|room <jid_username> <command> [args...]``.
    With no argument, or ``help``, the available commands are listed.

    :param recipient_str: the first word of the message, selecting the
        kind of recipient
    :param msg: the message containing the command
    :param args: the words following ``contact`` or ``room``
    :raises XMPPError: ``subscription-required`` if the sender has no
        session, ``bad-request`` if the command name is missing,
        ``item-not-found`` if the command does not exist
    """
    session = self.xmpp.get_session_from_jid(msg.get_from())

    recipient_cls = LegacyContact if recipient_str == "contact" else LegacyMUC

    if session is None:
        raise XMPPError("subscription-required")

    if not args or args[0] == "help":
        self.xmpp.delivery_receipt.ack(msg)
        self._help_recipient(msg, recipient_cls)
        return

    if len(args) == 1:
        self._help_recipient(msg, recipient_cls)
        # Name the right kind of recipient: this path serves rooms too.
        raise XMPPError(
            "bad-request",
            f"{recipient_str.capitalize()} commands require at least two parameters: "
            f"{recipient_str}_jid_username and command_name",
        )

    jid_username, command_name, *rest = args

    command = recipient_cls.commands_chat.get(command_name)
    if command is None:
        # Same user feedback as for unknown top-level commands: reply,
        # then raise item-not-found.
        self._not_found(msg, command_name)

    if recipient_cls is LegacyContact:
        legacy_id = await session.contacts.jid_username_to_legacy_id(jid_username)
        recipient = await session.contacts.by_legacy_id(legacy_id)
    else:
        legacy_id = await session.bookmarks.jid_username_to_legacy_id(jid_username)
        recipient = await session.bookmarks.by_legacy_id(legacy_id)

    result = await self.__wrap_handler(msg, command.run, recipient, *rest)  # type:ignore[arg-type,func-returns-value]
    self.xmpp.delivery_receipt.ack(msg)
    await self._handle_result(result, msg, session, recipient)

451 

452 def _help_recipient( 

453 self, msg: Message, recipient_cls: type[AnyContact | AnyMUC] 

454 ) -> None: 

455 msg.reply( 

456 "Available commands:\n" 

457 + "\n".join( 

458 f"{co.CHAT_COMMAND} ({co.NAME}): {co.HELP}" 

459 for co in recipient_cls.commands_chat.values() 

460 ) 

461 ).send() 

462 

463 

def percent_encode(jid: "JID") -> str:
    """
    Return the bare form of ``jid`` with its local part percent-encoded.

    The result is meant to follow ``xmpp:`` in a URI. The resource, if
    any, is not included.

    :param jid: the JID to encode
    :return: ``user@server``, or just ``server`` for a JID without a
        local part (instead of an invalid ``@server``)
    """
    if not jid.user:
        return f"{jid.server}"
    return f"{url_quote(jid.user)}@{jid.server}"

466 

467 

# Module-level logger, named after this module; used by
# ChatCommandProvider to record handler errors at debug level.
log = logging.getLogger(__name__)