Coverage for slidge/command/chat_command.py: 58%

172 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1# Handle slidge commands by exchanging chat messages with the gateway components. 

2 

3# Ad-hoc methods should provide a better UX, but some clients do not support them, 

4# so this is mostly a fallback. 

5import asyncio 

6import functools 

7import logging 

8from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, overload 

9from urllib.parse import quote as url_quote 

10 

11from slixmpp import JID, CoroutineCallback, Message, StanzaPath 

12from slixmpp.exceptions import XMPPError 

13from slixmpp.types import JidStr, MessageTypes 

14 

15from . import Command, CommandResponseType, Confirmation, Form, TableResult 

16from .categories import CommandCategory 

17 

18if TYPE_CHECKING: 

19 from ..core.gateway import BaseGateway 

20 

21 

class ChatCommandProvider:
    """
    Expose slidge commands through plain chat messages sent to the gateway.

    Messages addressed to the gateway's bare JID are intercepted. If an
    :meth:`input` call is pending for the sender, the message body resolves it.
    Otherwise the first word of the body selects ``help`` or a registered
    command.
    """

    # Reply template used when the first word of a message matches no command.
    UNKNOWN = "Wut? I don't know that command: {}"

    def __init__(self, xmpp: "BaseGateway"):
        """
        :param xmpp: the gateway component. A coroutine handler is registered on
            it for message stanzas addressed to its bare JID.
        """
        self.xmpp = xmpp
        # NOTE(review): not read anywhere in this class; kept in case other
        # modules rely on it.
        self._keywords = list[str]()
        # Maps a command's CHAT_COMMAND trigger word to the command.
        self._commands: dict[str, Command] = {}
        # Pending input() futures, keyed by the bare JID a reply is awaited from.
        self._input_futures = dict[str, asyncio.Future[str]]()
        self.xmpp.register_handler(
            CoroutineCallback(
                "chat_command_handler",
                StanzaPath(f"message@to={self.xmpp.boundjid.bare}"),
                self._handle_message,  # type: ignore
            )
        )

    def register(self, command: Command):
        """
        Register a command to be used via chat messages with the gateway

        Plugins should not call this, any class subclassing Command should be
        automatically added by slidge core.

        :param command: the new command
        :raises RuntimeError: if another command already uses the same
            ``CHAT_COMMAND`` trigger word
        """
        trigger = command.CHAT_COMMAND
        if trigger in self._commands:
            # RuntimeError does no %-formatting, so build the message explicitly.
            raise RuntimeError(f"There is already a command triggered by '{trigger}'")
        self._commands[trigger] = command

    @overload
    async def input(
        self,
        jid: JidStr,
        text: Optional[str] = ...,
        mtype: MessageTypes = ...,
        timeout: float = ...,
        *,
        blocking: Literal[False],
        **msg_kwargs,
    ) -> asyncio.Future[str]: ...

    @overload
    async def input(
        self,
        jid: JidStr,
        text: Optional[str] = ...,
        mtype: MessageTypes = ...,
        timeout: float = ...,
        blocking: Literal[True] = ...,
        **msg_kwargs,
    ) -> str: ...

    async def input(
        self,
        jid,
        text=None,
        mtype="chat",
        timeout=60,
        blocking=True,
        **msg_kwargs,
    ):
        """
        Request arbitrary user input using a simple chat message, and await the result.

        You shouldn't need to call this directly, but should instead use
        :meth:`.BaseSession.input` to directly target a user.

        NB: When using this, the next message that the user sends to the component will
        not be transmitted to :meth:`.BaseGateway.on_gateway_message`, but rather intercepted.
        Await the coroutine to get its content.

        :param jid: The JID we want input from
        :param text: A prompt to display for the user. If None, no prompt is sent.
        :param mtype: Message type, used for the prompt and the timeout notice
        :param timeout: Seconds to wait for the reply (only used when ``blocking``)
        :param blocking: If set to False, timeout has no effect and an :class:`asyncio.Future`
            is returned instead of a str
        :param msg_kwargs: Extra keyword arguments forwarded to ``send_message``
            for the prompt
        :return: The user's reply
        :raises XMPPError: ``remote-server-timeout`` if no reply arrives in time
        """
        jid = JID(jid)
        if text is not None:
            self.xmpp.send_message(
                mto=jid,
                mbody=text,
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
                **msg_kwargs,
            )
        fut = asyncio.get_running_loop().create_future()
        self._input_futures[jid.bare] = fut
        if not blocking:
            return fut
        try:
            await asyncio.wait_for(fut, timeout)
        except asyncio.TimeoutError:
            self.xmpp.send_message(
                mto=jid,
                mbody="You took too much time to reply",
                mtype=mtype,
                mfrom=self.xmpp.boundjid.bare,
            )
            raise XMPPError("remote-server-timeout", "You took too much time to reply")
        finally:
            # After a timeout or a cancellation, the (now cancelled) future may
            # still be registered. Drop it so the user's next message is not
            # swallowed. Only remove *our* future: a later input() call for the
            # same JID may have replaced it.
            if self._input_futures.get(jid.bare) is fut:
                del self._input_futures[jid.bare]

        return fut.result()

    async def _handle_message(self, msg: Message):
        """
        Dispatch a chat message addressed to the gateway's bare JID.

        Precedence:

        1. A pending :meth:`input` for the sender.
        2. ``help``.
        3. A registered command matched by the (lowercased) first word.

        Remaining space-separated words are passed to the command.
        """
        body = msg["body"]
        if not body:
            return

        # Error bounces (e.g. an undeliverable prompt coming back) carry the
        # original body. They must neither answer a pending input() nor be run
        # as commands, which could start a reply loop.
        if msg["type"] == "error":
            return

        mfrom = msg.get_from()
        if not mfrom.node:
            return  # ignore component and server messages

        fut = self._input_futures.pop(mfrom.bare, None)
        # The future is already done if its awaiting input() was cancelled.
        # set_result() would then raise InvalidStateError, so treat the message
        # as a regular one instead.
        if fut is not None and not fut.done():
            fut.set_result(body)
            return

        first_word, *rest = body.split(" ")
        first_word = first_word.lower()

        if first_word == "help":
            return self._handle_help(msg, *rest)

        command = self._commands.get(first_word)
        if command is None:
            return self._not_found(msg, first_word)

        try:
            session = command.raise_if_not_authorized(mfrom)
        except XMPPError as e:
            reply = msg.reply()
            reply["body"] = e.text
            reply.send()
            raise

        result = await self.__wrap_handler(msg, command.run, session, mfrom, *rest)
        self.xmpp.delivery_receipt.ack(msg)
        return await self._handle_result(result, msg, session)

    def __make_uri(self, body: str) -> str:
        """
        Build an ``xmpp:`` URI that, when followed, sends *body* to the gateway.

        The body is percent-encoded, as RFC 5122 requires for query values.
        This lets values containing spaces, ``;``, ``#`` or ``%`` survive.
        """
        return f"xmpp:{self.xmpp.boundjid.bare}?message;body={url_quote(str(body))}"

    async def _handle_result(self, result: CommandResponseType, msg: Message, session):
        """
        Render a command result as chat replies to *msg*, prompting for input as needed.

        - ``str`` / ``None``: sent as-is. An empty value becomes "End of command.".
        - :class:`Form`: each non-fixed field is asked for in turn (reply
          ``abort`` to cancel). The form handler is then called with the
          validated values.
        - :class:`Confirmation`: the prompt is sent. A reply not starting with
          "y" cancels; otherwise the handler is called.
        - :class:`TableResult`: all rows are sent in a single message.

        Results returned by handlers are rendered recursively. Other result
        types produce no reply.
        """
        if isinstance(result, str) or result is None:
            reply = msg.reply()
            reply["body"] = result or "End of command."
            reply.send()
            return

        if isinstance(result, Form):
            form_values = {}
            for header in (result.title, result.instructions):
                if header:
                    msg.reply(header).send()
            for field in result.fields:
                if field.type == "fixed":
                    msg.reply(f"{field.label or field.var}: {field.value}").send()
                    continue

                if field.type == "list-multi":
                    msg.reply(
                        "Multiple selection allowed, use new lines as a separator, ie, "
                        "one selected item per line. To select no item, reply with a space "
                        "(the punctuation)."
                    ).send()
                if field.options:
                    for o in field.options:
                        msg.reply(
                            f"{o['label']}: {self.__make_uri(o['value'])}"
                        ).send()
                if field.value:
                    msg.reply(f"Default: {field.value}").send()
                if field.type == "boolean":
                    msg.reply("yes: " + self.__make_uri("yes")).send()
                    msg.reply("no: " + self.__make_uri("no")).send()

                ans = await self.xmpp.input(
                    msg.get_from(), (field.label or field.var) + "? (or 'abort')"
                )
                if ans.lower() == "abort":
                    return await self._handle_result("Command aborted", msg, session)
                if field.type == "boolean":
                    ans = "true" if ans.lower() == "yes" else "false"

                if field.type.endswith("multi"):
                    # One item per line; a lone space means "nothing selected".
                    choices = [] if ans == " " else ans.split("\n")
                    form_values[field.var] = field.validate(choices)
                else:
                    form_values[field.var] = field.validate(ans)
            result = await self.__wrap_handler(
                msg,
                result.handler,
                form_values,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, Confirmation):
            yes_or_no = await self.input(msg.get_from(), result.prompt)
            if not yes_or_no.lower().startswith("y"):
                reply = msg.reply()
                reply["body"] = "Canceled"
                reply.send()
                return
            result = await self.__wrap_handler(
                msg,
                result.handler,
                session,
                msg.get_from(),
                *result.handler_args,
                **result.handler_kwargs,
            )
            return await self._handle_result(result, msg, session)

        if isinstance(result, TableResult):
            if not result.items:
                msg.reply("Empty results").send()
                return

            body = result.description + "\n"
            for item in result.items:
                for field in result.fields:
                    if field.type == "jid-single":
                        value = f"xmpp:{percent_encode(JID(item[field.var]))}"
                        if result.jids_are_mucs:
                            value += "?join"
                    else:
                        value = item[field.var]  # type:ignore
                    body += f"\n{field.label or field.var}: {value}"
            msg.reply(body).send()

    @staticmethod
    async def __wrap_handler(msg, f: Union[Callable, functools.partial], *a, **k):
        """
        Call a command handler with ``*a, **k``, awaiting the result if it is awaitable.

        This covers coroutine functions, partials wrapping them, and plain
        callables that return a coroutine or future. Any exception is logged at
        debug level and reported to the sender of *msg* as ``Error: <exception>``.
        In that case ``None`` is returned.
        """
        import inspect  # only needed here

        try:
            ret = f(*a, **k)
            if inspect.isawaitable(ret):
                ret = await ret
            return ret
        except Exception as e:
            log.debug("Error in %s", f, exc_info=e)
            reply = msg.reply()
            reply["body"] = f"Error: {e}"
            reply.send()

    def _handle_help(self, msg: Message, *rest):
        """
        Answer ``help`` (list of usable commands) or ``help <command>``.

        Some cases are answered like an unknown command, via :meth:`_not_found`
        (which raises):

        - unknown command words;
        - commands the sender is not authorized to use (also hidden from the
          listing);
        - more than one word after ``help``.
        """
        mfrom = msg.get_from()
        if not rest:
            reply = msg.reply()
            reply["body"] = self._help(mfrom)
            reply.send()
            return

        # Dispatch lowercases the trigger word, so help lookups do the same.
        command = self._commands.get(rest[0].lower()) if len(rest) == 1 else None
        if command is not None:
            try:
                command.raise_if_not_authorized(mfrom)
            except XMPPError:
                # Consistent with _help(), which skips such commands.
                command = None

        if command is None:
            self._not_found(msg, " ".join(rest))
        else:
            reply = msg.reply()
            reply["body"] = f"{command.CHAT_COMMAND}: {command.NAME}\n{command.HELP}"
            reply.send()

    def _help(self, mfrom: JID):
        """
        Build the ``help`` listing for *mfrom*.

        Commands the sender is not authorized to use are skipped. The rest are
        sorted by category name, then trigger word.
        """

        def sort_key(co):
            category = co.CATEGORY
            if isinstance(category, str):
                category_name = category
            elif isinstance(category, CommandCategory):
                category_name = category.name
            else:
                category_name = ""
            return category_name, co.CHAT_COMMAND

        lines = ["Available commands:"]
        for c in sorted(self._commands.values(), key=sort_key):
            try:
                c.raise_if_not_authorized(mfrom)
            except XMPPError:
                continue
            lines.append(f"{c.CHAT_COMMAND} -- {c.NAME}")
        return "\n".join(lines)

    def _not_found(self, msg: Message, word: str):
        """
        Tell the sender *word* is not a known command, then raise ``item-not-found``.

        :raises XMPPError: always
        """
        e = self.UNKNOWN.format(word)
        msg.reply(e).send()
        raise XMPPError("item-not-found", e)

312 

313 

def percent_encode(jid: "JID"):
    """
    Return *jid*'s bare form for use in an ``xmpp:`` URI.

    The localpart is percent-encoded. A JID without a localpart (e.g. a bare
    domain) is returned as just the domain, instead of the invalid ``@domain``.
    """
    if not jid.user:
        return jid.server
    return f"{url_quote(jid.user)}@{jid.server}"  # type:ignore

316 

317 

# Module-level logger, named after this module. ChatCommandProvider reports
# command-handler errors through it.
log = logging.getLogger(name=__name__)