Coverage for slidge/command/chat_command.py: 57%

183 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

1# Handle slidge commands by exchanging chat messages with the gateway components. 

2 

3# Ad-hoc commands (XEP-0050) should provide a better UX, but some clients do not support them, 

4# so this is mostly a fallback. 

5import asyncio 

6import functools 

7import inspect 

8import logging 

9from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, overload 

10from urllib.parse import quote as url_quote 

11 

12from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

13from slixmpp.exceptions import XMPPError 

14from slixmpp.types import JidStr, MessageTypes 

15 

16from . import Command, CommandResponseType, Confirmation, Form, TableResult 

17from .categories import CommandCategory 

18 

19if TYPE_CHECKING: 

20 from ..core.gateway import BaseGateway 

21 

22 

class ChatCommandProvider:
    """
    Dispatch slidge commands received as chat messages addressed to the gateway.

    The first word of a message body selects a registered :class:`Command`
    (``help`` is handled specially) and the command's result is rendered back
    to the user as chat messages. The same message handler also serves
    :meth:`input`, which captures the next message a given user sends.
    """

    # Reply template used when a message's first word matches no command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        self._keywords = list[str]()
        # Maps each command's CHAT_COMMAND trigger word to the command.
        self._commands: dict[str, Command] = {}
        # Pending input() requests, keyed by the bare JID a reply is expected from.
        self._input_futures = dict[str, asyncio.Future[str]]()
        # Route every message addressed to the component's bare JID to
        # _handle_message.
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        t = command.CHAT_COMMAND
        if t in self._commands:
            raise RuntimeError(f"There is already a command triggered by '{t}'")
        self._commands[t] = command

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[True] = True,
        **msg_kwargs: Any,
    ) -> str: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: Literal[False],
        **msg_kwargs: Any,
    ) -> asyncio.Future[str]: ...

    async def input(
        self,
        jid: JidStr,
        text: str | None = None,
        *,
        mtype: MessageTypes = "chat",
        timeout: int = 60,
        blocking: bool = True,
        **msg_kwargs: Any,
    ) -> str | asyncio.Future[str]:
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly; instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from (matched on its bare JID)
        :param text: A prompt to display for the user
        :param mtype: Message type
        :param timeout: How many seconds to wait for the reply (blocking mode only)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments passed to ``send_message`` for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives within *timeout*
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        f = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = f
        if not blocking:
            return f
        try:
            return await asyncio.wait_for(f, timeout)
        except asyncio.TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # Only drop the entry if it is still ours: a later input() call for
            # the same JID may have replaced it while we were waiting, and on
            # success _handle_message has already popped it.
            if self._input_futures.get(jid.bare) is f:
                del self._input_futures[jid.bare]

    async def _handle_message(self, msg: Message):
        """
        Handle a chat message sent to the gateway's bare JID.

        Messages without a body, or from JIDs without a localpart (servers and
        components), are ignored. If an :meth:`input` call is waiting for this
        sender, the body is handed to it. Otherwise the first word selects the
        command to run and the remaining words are passed as its arguments.

        :raises XMPPError: when the command is unknown or the sender is not
            authorized to use it (after replying with an explanation)
        """
        body = msg["body"]
        if not body:
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        f = self._input_futures.pop(mfrom.bare, None)
        # The future may already be done if its owner gave up on it (e.g. a
        # cancelled wait); set_result() would raise InvalidStateError, so
        # process the message as a regular command instead.
        if f is not None and not f.done():
            f.set_result(body)
            return

        # split() without an argument tolerates repeated/trailing whitespace
        # and newlines, which split(" ") turned into empty arguments.
        words = body.split()
        if not words:
            return  # whitespace-only body
        first_word, *rest = words
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when opened, sends *body* to the gateway.

        The body is percent-encoded, as RFC 5122 requires for query values, so
        values containing spaces, ``;``, ``&``, ``#``... yield well-formed links.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(body, safe='')}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command result to the user as chat messages.

        - ``str``/``None``: sent as a reply ("End of command." when empty).
        - :class:`Form`: fields are asked one by one; on a
          ``remote-server-timeout`` error the form's ``timeout_handler`` is
          called if it has one.
        - :class:`Confirmation`: the prompt is asked; the handler runs only if
          the answer starts with "y".
        - :class:`TableResult`: rendered as text, JID fields as ``xmpp:`` URIs.

        Multi-step results are handled recursively.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            try:
                return await self.__handle_form(result, msg, session)
            except XMPPError as e:
                if (
                    result.timeout_handler is None
                    or e.condition != "remote-server-timeout"
                ):
                    raise
                return result.timeout_handler()

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if not result.items:
                msg.reply("Empty results").send()
                return

            # Description, a blank line, then one "label: value" line per field.
            lines = [result.description, ""]
            for item in result.items:
                for f in result.fields:
                    if f.type == "jid-single":
                        j = JID(item[f.var])
                        value = f"xmpp:{percent_encode(j)}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[f.var]  # type:ignore
                    lines.append(f"{f.label or f.var}: {value}")
            msg.reply("\n".join(lines)).send()

    async def __handle_form(self, result: Form, msg: Message, session):
        """
        Walk the user through *result* one field at a time, then call the
        form's handler with the collected, validated values.

        Replying "abort" to any question ends the command.
        """
        form_values = {}
        for t in result.title, result.instructions:
            if t:
                msg.reply(t).send()
        for f in result.fields:
            if f.type == "fixed":
                msg.reply(f"{f.label or f.var}: {f.value}").send()
                continue

            if f.type == "list-multi":
                msg.reply(
                    "Multiple selection allowed, use new lines as a separator, ie, "
                    "one selected item per line. To select no item, reply with a space "
                    "(the punctuation)."
                ).send()
            if f.options:
                for o in f.options:
                    msg.reply(f"{o['label']}: {self.__make_uri(o['value'])}").send()
            if f.value:
                msg.reply(f"Default: {f.value}").send()
            if f.type == "boolean":
                msg.reply("yes: " + self.__make_uri("yes")).send()
                msg.reply("no: " + self.__make_uri("no")).send()

            ans = await self.xmpp.input(
                msg.get_from(), (f.label or f.var) + "? (or 'abort')"
            )
            # Normalized copy for keyword matching only; the raw answer is kept
            # for validation (a lone " " is meaningful for multi fields).
            keyword = ans.strip().lower()
            if keyword == "abort":
                return await self._handle_result("Command aborted", msg, session)
            if f.type == "boolean":
                ans = "true" if keyword in ("yes", "y", "true") else "false"

            if f.type.endswith("multi"):
                choices = [] if ans == " " else ans.split("\n")
                form_values[f.var] = f.validate(choices)
            else:
                form_values[f.var] = f.validate(ans)
        result = await self.__wrap_handler(
            msg,
            result.handler,
            form_values,
            session,
            msg.get_from(),
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self._handle_result(result, msg, session)

    @staticmethod
    async def __wrap_handler(msg, f: Union[Callable, functools.partial], *a, **k):
        """
        Call a command handler, awaiting its result when it is awaitable.

        Any exception raised by the handler is logged at debug level and
        reported to the user as an ``Error: ...`` reply; ``None`` is then
        returned.
        """
        try:
            result = f(*a, **k)
            # Covers coroutine functions, partials/bound methods wrapping them,
            # and plain callables that return an awaitable.
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest) -> None:
        """
        Reply to ``help`` (command list) or ``help <command>`` (command details).

        :raises XMPPError: ``item-not-found`` for an unknown help topic
        """
        if not rest:
            reply = msg.reply()
            reply["body"] = self._help(msg.get_from())
            reply.send()
        # Lower-case like _handle_message does, so "help Foo" finds "foo".
        elif len(rest) == 1 and (command := self._commands.get(rest[0].lower())):
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()
        else:
            self._not_found(msg, " ".join(rest))

    def _help(self, mfrom: JID):
        """
        Build the "Available commands" listing for *mfrom*.

        Commands are sorted by category name, then trigger word; commands the
        sender is not authorized to use are omitted.
        """
        session = self.xmpp.get_session_from_jid(mfrom)

        def sort_key(co: Command) -> tuple[str, str]:
            cat = co.CATEGORY
            if isinstance(cat, str):
                cat_name = cat
            elif isinstance(cat, CommandCategory):
                cat_name = cat.name
            else:
                cat_name = ""
            return cat_name, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom, fetch_session=False, session=session)
            except XMPPError:
                continue
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str):
        """
        Tell the user *word* is not a known command.

        :raises XMPPError: always, ``item-not-found``
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

329 

330 

def percent_encode(jid: "JID") -> str:
    """
    Render *jid* as the address part of an ``xmpp:`` URI.

    The localpart is percent-encoded and the resource is dropped. A JID with
    no localpart (e.g. a server or component) is rendered as its bare domain,
    instead of the malformed ``@domain``.
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"  # type:ignore

333 

334 

# Module logger. It is only looked up when a command handler fails, so
# defining it after the class that uses it is fine.
log: logging.Logger = logging.getLogger(name=__name__)