Coverage for slidge / command / adhoc.py: 88%

161 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

1import asyncio 

2import functools 

3import inspect 

4import logging 

5from collections.abc import Callable 

6from functools import partial 

7from typing import TYPE_CHECKING, Any, Optional 

8 

9from slixmpp import JID, Iq 

10from slixmpp.exceptions import XMPPError 

11from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

12from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

13 

14from ..core import config 

15from ..util.util import strip_leading_emoji 

16from . import Command, CommandResponseType, Confirmation, Form, TableResult 

17from .base import FormField 

18from .categories import CommandCategory 

19 

20if TYPE_CHECKING: 

21 from ..core.gateway import BaseGateway 

22 from ..core.session import BaseSession 

23 

24 

# The dict that carries the state of one adhoc command session between steps.
# This module reads its "id" and "from" keys and writes its "has_next",
# "next", "payload" and "notes" keys.
AdhocSessionType = dict[str, Any]

26 

27 

class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands are added with :meth:`register`. A command without a category is
    exposed directly under its own node; commands sharing a category are
    exposed through a single node that first presents a form listing the
    commands of that category the requester is allowed to run.
    """

    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        """
        Set up the registries and install :meth:`get_items` as the disco
        items handler for the adhoc commands node.

        :param xmpp: the gateway whose ``xep_0030`` and ``xep_0050`` plugins
            are used by this provider
        """
        self.xmpp = xmpp
        # Commands without a category, keyed by their node.
        self._commands = dict[str, Command]()
        # Commands with a category, grouped by the node of that category.
        self._categories = dict[str, list[Command]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # Pending form timeout timers, keyed by adhoc session id.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a command that was registered without a category.

        :param command: the command to run
        :param iq: the IQ that started the adhoc session
        :param adhoc_session: the adhoc session dict to fill
        :return: the updated adhoc session dict
        :raises XMPPError: if the requester is not authorized, or if running
            the command fails
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a category node: present the list of commands of this
        category that the requester is authorized to run.

        :param category: the category that was requested
        :param iq: the IQ that started the adhoc session
        :param adhoc_session: the adhoc session dict to fill
        :return: the updated adhoc session dict, holding the selection form
        :raises XMPPError: ``not-authorized`` if the requester cannot run any
            command of this category
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # No session for this requester; individual commands decide
            # below whether that is acceptable.
            session = None

        ifrom = iq.get_from()
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(ifrom)
            except XMPPError:
                continue
            commands[command.NODE] = command

        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )

        options = [
            {
                "label": strip_leading_emoji_if_needed(command.NAME),
                "value": command.NODE,
            }
            for command in commands.values()
        ]
        selection_form = Form(
            category.name,
            "",
            [
                FormField(
                    var="command",
                    label="Command",
                    type="list-single",
                    options=options,
                )
            ],
            partial(self.__handle_category_choice, commands),
        )
        return await self.__handle_result(session, selection_form, adhoc_session)

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ):
        """
        Run the command that was picked in the category selection form.

        :param commands: the commands that were offered, keyed by node
        :param form_values: the submitted form values; ``"command"`` holds
            the node of the chosen command
        :param session: the session of the requester
        :param jid: the JID of the requester
        :return: whatever the chosen command's ``run`` returns
        """
        command = commands[form_values["command"]]
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate the return value of a command step into the adhoc session
        dict (``has_next``, ``next``, ``payload`` and ``notes`` keys).

        :param session: the session of the requester, if any
        :param result: what the command step returned
        :param adhoc_session: the adhoc session dict to update in place
        :return: the same ``adhoc_session`` dict, updated
        :raises XMPPError: ``internal-server-error`` if ``result`` is of an
            unsupported type
        """
        if isinstance(result, str) or result is None:
            # Final step: only a note is shown to the user.
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                # The timer is cancelled in __wrap_form_handler if the form
                # is submitted before FORM_TIMEOUT seconds have elapsed.
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Called when a form was not submitted within ``FORM_TIMEOUT`` seconds.

        Removes the adhoc session from the ``xep_0050`` plugin, then calls the
        timeout handler of the form.

        :param handler: the timeout handler of the form
        :param session_id: the id of the adhoc session that timed out
        """
        # The timer has fired, so its handle is of no further use; drop it so
        # that timed-out sessions do not accumulate in the dict.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Callable | functools.partial, *a, **k):  # type: ignore
        """
        Call ``f``, awaiting it if it is a coroutine function (or a partial
        wrapping one), and turn unexpected exceptions into XMPP errors.

        :param f: the handler to call
        :param a: positional arguments passed to ``f``
        :param k: keyword arguments passed to ``f``
        :return: what ``f`` returns
        :raises XMPPError: as raised by ``f``, or ``internal-server-error``
            if ``f`` raised any other exception
        """
        is_async = inspect.iscoroutinefunction(f) or (
            hasattr(f, "func") and inspect.iscoroutinefunction(f.func)
        )
        try:
            if is_async:
                return await f(*a, **k)
            return f(*a, **k)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Form` was presented: cancel its timeout,
        pass the submitted values to the form handler and process its result.

        :param session: the session of the requester, if any
        :param result: the form that was presented to the requester
        :param form: the form submitted by the requester
        :param adhoc_session: the adhoc session dict to update
        :return: the updated adhoc session dict
        """
        timer = self.__timeouts.pop(adhoc_session["id"], None)
        if timer is not None:
            log.debug("Canceled form timeout for adhoc session %s", adhoc_session["id"])
            timer.cancel()
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Confirmation` was presented.

        If the submitted form has a truthy ``confirm`` value, the confirmation
        handler is run; its return value is replaced by
        ``confirmation.success`` when that attribute is truthy. Otherwise a
        cancellation notice is used as the result.

        :param session: the session of the requester, if any
        :param confirmation: the confirmation that was presented
        :param form: the form submitted by the requester
        :param adhoc_session: the adhoc session dict to update
        :return: the updated adhoc session dict
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care
        of that.

        :param command: the command to register
        :param jid: the JID the command is attached to; defaults to the
            bound JID of the gateway
        :raises RuntimeError: if a command without a category is already
            registered for the same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                # First command of this category: expose the category node.
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester neither matches the
            gateway's JID validator nor is listed in ``config.ADMINS``
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        # Fetched once here and handed to every authorization check below.
        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            authorized = True
            if item["node"] in self._categories:
                # A category is listed as soon as one of its commands is
                # accessible to the requester.
                for command in self._categories[item["node"]]:
                    try:
                        command.raise_if_not_authorized(
                            ifrom, fetch_session=False, session=session
                        )
                    except XMPPError:
                        authorized = False
                    else:
                        authorized = True
                        break
            else:
                # NOTE(review): an item whose node is neither a category nor
                # a key of self._commands raises KeyError here; confirm that
                # every item of this node is added through register().
                try:
                    self._commands[item["node"]].raise_if_not_authorized(
                        ifrom, fetch_session=False, session=session
                    )
                except XMPPError:
                    authorized = False

            if authorized:
                filtered_items.append(item)

        return filtered_items

315 

316 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return ``text`` passed through ``strip_leading_emoji`` when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is truthy, and ``text``
    unchanged otherwise.

    :param text: the name of a command or of a command category
    :return: the text to use as the adhoc command name or label
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text

321 

322 

# Module-level logger. It is defined after AdhocProvider, which is fine
# because the class only looks the name up when its methods are called.
log = logging.getLogger(__name__)