Coverage for slidge / command / chat_command.py: 57%

185 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-03-13 22:59 +0000

1# Handle slidge commands by exchanging chat messages with the gateway components. 

2 

3# Ad-hoc methods should provide a better UX, but some clients do not support them, 

4# so this is mostly a fallback. 

5import asyncio 

6import functools 

7import inspect 

8import logging 

9from collections.abc import Callable 

10from typing import TYPE_CHECKING, Any, Literal, Never, overload 

11from urllib.parse import quote as url_quote 

12 

13from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

14from slixmpp.exceptions import XMPPError 

15from slixmpp.types import JidStr, MessageTypes 

16 

17from . import Command, CommandResponseType, Confirmation, Form, TableResult 

18from .categories import CommandCategory 

19 

20if TYPE_CHECKING: 

21 from ..core.gateway import BaseGateway 

22 

23 

class ChatCommandProvider:
    """
    Dispatch slidge commands received as plain chat messages addressed to the
    gateway component's bare JID.

    This is a fallback for clients without ad-hoc command support: the first
    word of a message selects ``help`` or a registered :class:`Command`, and
    multi-step interactions (forms, confirmations) are driven by :meth:`input`,
    which intercepts the user's next message.
    """

    # Reply template for an unknown command; ``{}`` receives the word the user typed.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # CHAT_COMMAND trigger word -> command
        self._commands: dict[str, Command] = {}
        # Pending input() calls, keyed by the bare JID whose next message answers them.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command) -> None:
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            # Exception constructors do not %-interpolate extra arguments
            # (unlike logging calls), so format the message up front.
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self, jid: JidStr, text: str | None = None, *, blocking: Literal[False] = ...
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: Seconds to wait for a reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments for ``send_message`` when sending ``text``
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives within ``timeout``
        :return: The user's reply
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f: asyncio.Future[str] = asyncio.get_running_loop().create_future()
        # Overwrites any still-pending future for the same bare JID.
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            await asyncio.wait_for(f, timeout)
        except TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            # Only unregister our own future: a later input() call for the same
            # JID may have replaced it, and that one must stay pending.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]
            raise XMPPError("remote-server-timeout", "You took too much time to reply")

        return f.result()

    async def _handle_message(self, msg: Message) -> None:
        """
        Handle a message addressed to the component's bare JID.

        Messages without a body, or from a JID without a local part, are
        ignored. A pending :meth:`input` for the sender takes precedence;
        otherwise the first word (lowercased) selects ``help`` or a command.

        :raises XMPPError: if the command is unknown or the sender is not
            authorized to run it (a chat reply is sent first in both cases)
        """
        if not msg["body"]:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        f = self._input_futures.pop(mfrom.bare, None)
        # A non-blocking caller may have cancelled its future; a done future
        # cannot take a result, so fall through and treat this as a command.
        if f is not None and not f.done():
            f.set_result(msg["body"])
            return

        c = msg["body"]
        first_word, *rest = c.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            self._not_found(msg, first_word)  # always raises
            return

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:<gateway>?message;body=...`` URI (XEP-0147 style) that
        pre-fills ``body`` as a message to the gateway.

        The body is percent-encoded so that values containing spaces, ``;``,
        ``&``, ``#`` or non-ASCII characters still produce a well-formed URI.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command's return value as chat replies to ``msg``.

        - ``str`` / ``None``: sent as-is; a falsy result becomes "End of command.".
        - :class:`Form`: fields are prompted one by one; if an input times out
          and the form has a ``timeout_handler``, that handler's value is returned.
        - :class:`Confirmation`: asks for an answer starting with "y", then runs
          the handler and renders its result recursively.
        - :class:`TableResult`: one ``label: value`` line per field per item.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            try:
                return await self.__handle_form(result, msg, session)
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if len(result.items) == 0:
                msg.reply("Empty results").send()
                return

            body = result.description + "\n"
            for item in result.items:
                for f in result.fields:
                    if f.type == "jid-single":
                        j = JID(item[f.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[f.var]  # type:ignore
                    body += f"\n{f.label or f.var}: {value}"
            msg.reply(body).send()

    async def __handle_form(self, result: Form, msg: Message, session):
        """
        Fill ``result`` by prompting for each non-fixed field in turn, then call
        the form's handler with the collected values and render its result.

        Answering ``abort`` (any case) to a prompt stops the form with
        "Command aborted". Options and boolean choices are listed as ``xmpp:``
        URIs that pre-fill the corresponding answer.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
                continue

            if f.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if f.options:
                for o in f.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if f.value:
                msg.reply(f"Default: {f.value}").send()
            if f.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(), (f.label or f.var) + "? (or 'abort')"
            )
            if ans.lower() == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if f.type == "boolean":
                ans = "true" if ans.lower() == "yes" else "false"

            if f.type.endswith("multi"):
                # A lone space is the advertised way to select no item.
                choices = [] if ans == " " else ans.split("\n")
                form_values[f.var] = f.validate(choices)
            else:
                form_values[f.var] = f.validate(ans)

        handler_result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(handler_result, msg, session)

    @staticmethod
    async def __wrap_handler(msg, f: Callable | functools.partial, *a, **k):
        """
        Call a command handler, awaiting its return value if it is awaitable.

        Checking the returned value (rather than whether ``f`` is a coroutine
        function) also covers partials, async callable objects and sync
        wrappers around coroutines. Any exception is logged at debug level and
        reported to the user as an ``Error: ...`` reply; ``None`` is returned.
        """
        try:
            ret = f(*a, **k)
            if inspect.isawaitable(ret):
                ret = await ret
            return ret
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest: str) -> None:
        """
        Reply to ``help`` (list the commands available to the sender) or to
        ``help <command>`` (describe one command).

        :raises XMPPError: ``item-not-found`` if the requested command is unknown
        """
        if len(rest) == 0:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
            return
        # Lowercase like the first word is lowercased when dispatching commands.
        if len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
            return
        # Echo what the user typed rather than the repr of the argument tuple.
        self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID) -> str:
        """
        Build the list of commands ``mfrom`` is authorized to use, sorted by
        category name then trigger word.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        def sort_key(co):
            cat = co.CATEGORY
            if isinstance(cat, str):
                cat_name = cat
            elif isinstance(cat, CommandCategory):
                cat_name = cat.name
            else:
                cat_name = ""
            return cat_name, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str) -> Never:
        """
        Tell the user ``word`` is not a known command, then raise.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

331 

332 

def percent_encode(jid: "JID") -> str:
    """
    Return the bare form of ``jid`` for an ``xmpp:`` URI, with the local part
    percent-encoded.

    A JID without a local part (e.g. a server or component) yields just its
    domain instead of a malformed ``@domain``. The resource is never included.

    :param jid: the JID to encode
    :return: ``quoted-localpart@domain``, or ``domain`` when there is no local part
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"

335 

336 

# Module-level logger; ChatCommandProvider uses it to record handler errors at debug level.
log: logging.Logger = logging.getLogger(__name__)