Coverage for slidge / command / chat_command.py: 57%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

# Handle slidge commands by exchanging chat messages with the gateway component.
#
# Ad-hoc commands (XEP-0050) provide a better UX, but some clients do not support
# them, so this chat-based interface is mostly a fallback.

5import asyncio 

6import functools 

7import inspect 

8import logging 

9from collections.abc import Callable 

10from typing import TYPE_CHECKING, Any, Literal, overload 

11from urllib.parse import quote as url_quote 

12 

13from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

14from slixmpp.exceptions import XMPPError 

15from slixmpp.types import JidStr, MessageTypes 

16 

17from . import Command, CommandResponseType, Confirmation, Form, TableResult 

18from .categories import CommandCategory 

19 

20if TYPE_CHECKING: 

21 from ..core.gateway import BaseGateway 

22 

23 

class ChatCommandProvider:
    """
    Expose slidge commands through plain chat messages sent to the gateway.

    Every message addressed to the gateway's bare JID goes through
    :meth:`_handle_message`, which does one of three things:

    - resolves a pending :meth:`input` request from the same user;
    - shows help;
    - runs the command whose ``CHAT_COMMAND`` equals the lowercased first
      word of the body. The remaining space-separated words are passed to
      the command as positional arguments.
    """

    # Reply template used when a trigger word matches no registered command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Not read anywhere in this class; kept in case external code relies on it.
        self._keywords = list[str]()
        # Registered commands, keyed by their CHAT_COMMAND trigger word.
        self._commands: dict[str, Command] = {}
        # Pending input() requests, keyed by the bare JID we await a reply from.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway.

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        Note that :meth:`_handle_message` lowercases the first word of incoming
        messages before lookup, so a trigger containing uppercase letters can
        never be matched.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same trigger
        """
        trigger = command.CHAT_COMMAND
        if trigger in self._commands:
            raise RuntimeError(f"There is already a command triggered by '{trigger}'")
        self._commands[trigger] = command

    @overload
    async def input(self, jid: JidStr, text: str | None = None) -> str: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[False],
        **msg_kwargs: Any,
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly; use :meth:`.BaseSession.input`
        instead to target a user.

        NB: When using this, the next message that the user sends to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        Only one request per bare JID can be pending. A new call for the same
        JID replaces the previous one.

        :param jid: The JID we want input from (replies are matched on its bare part)
        :param text: A prompt to display for the user; nothing is sent if ``None``
        :param mtype: Message type of the prompt and of the timeout notice
        :param timeout: How many seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message`` for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives within ``timeout``
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        future = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = future
        if not blocking:
            return future
        try:
            return await asyncio.wait_for(future, timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is an alias of the builtin only since 3.11;
            # catching it explicitly also works on 3.10.
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # On success _handle_message has already popped the future. On timeout or
            # cancellation we must drop it ourselves, but only if a newer input() call
            # for the same JID has not replaced it in the meantime.
            if self._input_futures.get(jid.bare) is future:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message):
        """
        Entry point for every message sent to the gateway's bare JID.

        Messages without a body, error bounces and messages from JIDs without a
        node (servers/components) are ignored. A message from a user with a
        pending :meth:`input` request resolves that request. Otherwise the first
        word selects "help" or a registered command.
        """
        if not msg["body"]:
            return

        # Error stanzas may carry the original body back; never treat a bounce as a
        # command (replying to it could ping-pong) or as the answer to an input().
        if msg["type"] == "error":
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        pending = self._input_futures.pop(mfrom.bare, None)
        # A cancelled/finished future would raise InvalidStateError on set_result;
        # in that case handle the message as a regular one instead.
        if pending is not None and not pending.done():
            pending.set_result(msg["body"])
            return

        first_word, *rest = msg["body"].split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that pre-fills *body* in a message to the gateway.

        The body is percent-encoded as RFC 5122 requires. Without this, spaces
        or URI delimiters (``&``, ``;``, ``#``...) in option values would produce
        broken links.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body, safe='')}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command's return value to the user, recursing into follow-up steps.

        - ``str`` / ``None``: sent as a reply (``None`` or ``""`` become "End of command.").
        - :class:`Form`: fields are asked one by one. A ``remote-server-timeout``
          error is delegated to ``result.timeout_handler`` when one is set.
        - :class:`Confirmation`: the user is prompted. Any answer not starting
          with "y" cancels; otherwise the confirmation handler runs.
        - :class:`TableResult`: rendered as "label: value" lines.

        :param result: what the command handler returned
        :param msg: the message that triggered the command; replies go to its sender
        :param session: the sender's session, forwarded to follow-up handlers
        """
        if result is None or isinstance(result, str):
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            try:
                return await self.__handle_form(result, msg, session)
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            follow_up = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(follow_up, msg, session)

        if isinstance(result, TableResult):
            return self.__send_table(result, msg)

        # Previously such results were dropped without a trace.
        log.warning("Unhandled chat command result type: %r", result)

    @staticmethod
    def __send_table(result: TableResult, msg: Message) -> None:
        """
        Reply with *result* rendered as text.

        The reply is the description, a blank line, then one "label: value" line
        per field of each item. ``jid-single`` values become ``xmpp:`` URIs,
        with ``?join`` appended when the JIDs are MUCs.
        """
        if not result.items:
            msg.reply("Empty results").send()
            return

        lines = [result.description, ""]
        for item in result.items:
            for field in result.fields:
                if field.type == "jid-single":
                    value = f"xmpp:{percent_encode(JID(item[field.var]))}"
                    if result.jids_are_mucs:
                        value += "?join"
                else:
                    value = item[field.var]  # type:ignore
                lines.append(f"{field.label or field.var}: {value}")
        msg.reply("\n".join(lines)).send()

    async def __handle_form(self, result: Form, msg: Message, session):
        """
        Fill *result* field by field through chat, then run the form's handler.

        Title and instructions are sent first. "fixed" fields are only displayed.
        Every other field is prompted for with ``self.xmpp.input``, after listing:

        - its options, as clickable ``xmpp:`` URIs;
        - its default value;
        - for booleans, yes/no shortcuts.

        Replying "abort" to any prompt ends the command.

        :return: whatever :meth:`_handle_result` returns for the handler's result
        :raises: whatever ``self.xmpp.input`` or ``field.validate`` raise, unhandled here
        """
        form_values = {}
        for text in (result.title, result.instructions):
            if text:
                msg.reply(text).send()

        for field in result.fields:
            label = field.label or field.var
            if field.type == "fixed":
                msg.reply(f"{label}: {field.value}").send()
                continue

            if field.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            for option in field.options or ():
                msg.reply(f"{option['label']}: {self.__make_uri(option['value'])}").send()
            if field.value:
                msg.reply(f"Default: {field.value}").send()
            if field.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(msg.get_from(), label + "? (or 'abort')")
            normalized = ans.strip().lower()
            if normalized == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if field.type == "boolean":
                # Accept the usual affirmative spellings, including the XEP-0004
                # lexical forms, instead of only the literal "yes".
                ans = "true" if normalized in ("yes", "y", "true", "1") else "false"

            if field.type.endswith("multi"):
                # A lone space is the documented way to select nothing.
                choices = [] if ans == " " else ans.split("\n")
                form_values[field.var] = field.validate(choices)
            else:
                form_values[field.var] = field.validate(ans)

        handler_result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(handler_result, msg, session)

    @staticmethod
    async def __wrap_handler(msg, f: Callable | functools.partial, *a, **k):
        """
        Call a command handler, awaiting its result if it returned an awaitable.

        This covers coroutine functions, partials wrapping them, and any other
        callable returning an awaitable. Any ``Exception`` is logged at debug
        level and reported to the user as "Error: ..."; ``None`` is then
        returned.
        """
        try:
            outcome = f(*a, **k)
            if inspect.isawaitable(outcome):
                outcome = await outcome
            return outcome
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest) -> None:
        """
        Reply to "help" (list the commands) or "help <command>" (describe one command).

        Empty words, produced by repeated spaces, are ignored. The command name
        is lowercased, like in :meth:`_handle_message`.

        :raises XMPPError: ``item-not-found`` if the requested command is unknown
        """
        words = [word for word in rest if word]
        if not words:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
            return
        if len(words) == 1 and (command := self._commands.get(words[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
            return
        self._not_found(msg, " ".join(words))

    def _help(self, mfrom: JID):
        """
        Build the "Available commands" listing for *mfrom*.

        Commands are sorted by category name, then trigger word. Commands that
        *mfrom* is not authorized to use are skipped.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        def sort_key(co: "Command") -> tuple[str, str]:
            category = co.CATEGORY
            if isinstance(category, str):
                category_name = category
            elif isinstance(category, CommandCategory):
                category_name = category.name
            else:
                category_name = ""
            return category_name, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str):
        """
        Tell the sender that *word* is not a known command, then raise.

        :raises XMPPError: always, with condition ``item-not-found``
        """
        error_text = self.UNKNOWN.format(word)
        msg.reply(error_text).send()
        raise XMPPError("item-not-found", error_text)

330 

331 

332def percent_encode(jid: JID) -> str: 

333 return f"{url_quote(jid.user)}@{jid.server}" # type:ignore 

334 

335 

# Module-level logger, named after this module so its output can be filtered per module.
log: logging.Logger = logging.getLogger(name=__name__)