Coverage for slidge / command / adhoc.py: 88%

160 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-06 15:18 +0000

1import asyncio 

2import functools 

3import inspect 

4import logging 

5from functools import partial 

6from typing import TYPE_CHECKING, Any, Callable, Optional, Union 

7 

8from slixmpp import JID, Iq 

9from slixmpp.exceptions import XMPPError 

10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

12 

13from ..core import config 

14from ..util.util import strip_leading_emoji 

15from . import Command, CommandResponseType, Confirmation, Form, TableResult 

16from .base import FormField 

17from .categories import CommandCategory 

18 

19if TYPE_CHECKING: 

20 from ..core.gateway import BaseGateway 

21 from ..core.session import BaseSession 

22 

23 

# The XEP-0050 session dict handed to (and returned by) the adhoc command
# handlers below. This module reads/writes the keys "id", "from", "next",
# "has_next", "payload" and "notes"; other keys are presumably managed by
# slixmpp's xep_0050 plugin — TODO confirm.
AdhocSessionType = dict[str, Any]

25 

26 

class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a category are registered directly as XEP-0050 nodes.
    Commands sharing a category are grouped behind a single node which first
    replies with a form listing the commands of that category the requester
    is authorized to run.
    """

    # Delay after which a pending form that defines a ``timeout_handler`` is
    # abandoned (its adhoc session is dropped and the handler is called).
    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # node -> command, for commands registered without a category
        self._commands = dict[str, Command]()
        # category node -> commands registered in that category
        self._categories = dict[str, list[Command]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # adhoc session id -> timer firing ``__wrap_timeout`` for a pending form
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point of an uncategorized command: check that the requester may
        run it, run it, and translate its result into the XEP-0050 session.

        :raises XMPPError: if the requester is not authorized, or the command
            fails.
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point of a category node: reply with a single-choice form listing
        the commands of this category that the requester may run.

        :raises XMPPError: ``not-authorized`` if no command of the category is
            available to the requester.
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # Not having a session is not fatal here: authorization is decided
            # per command below (presumably some commands need no session).
            session = None
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseType:
        """
        Run the command the user picked in the category form.

        :param commands: the commands that were offered, keyed by node
        :param form_values: the submitted form values
        :raises XMPPError: ``bad-request`` if the submitted value is not one of
            the offered commands.
        """
        choice = form_values.get("command")
        command = commands.get(choice) if choice is not None else None
        if command is None:
            # A forged or stale submission must not surface as an internal error.
            raise XMPPError("bad-request", f"Unknown command: {choice}")
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command handler's return value into the XEP-0050 session.

        - ``str`` or ``None``: the command ends with an informative note.
        - :class:`Form`: the form is sent; its handler processes the answer.
          If the form has a ``timeout_handler``, it is scheduled to run after
          ``FORM_TIMEOUT`` seconds unless the form is answered first.
        - :class:`Confirmation`: a confirmation form is sent.
        - :class:`TableResult`: the table is sent and the command ends.

        :raises XMPPError: ``internal-server-error`` for any other result type.
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                session_id = adhoc_session["id"]
                # Never leave an older timer running for the same session.
                self.__cancel_timeout(session_id)
                self.__timeouts[session_id] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(self.__wrap_timeout, result.timeout_handler, session_id),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        log.error("Unexpected adhoc command result: %r", result)
        raise XMPPError("internal-server-error", text="OOPS!")

    def __cancel_timeout(self, session_id: str) -> None:
        """Cancel and forget the pending form timer of an adhoc session, if any."""
        timer = self.__timeouts.pop(session_id, None)
        if timer is not None:
            log.debug("Canceled form timeout of adhoc session %s", session_id)
            timer.cancel()

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Event-loop callback for a form left unanswered for ``FORM_TIMEOUT``
        seconds: drop the XEP-0050 session, then call the form's timeout handler.
        """
        # The timer has fired: forget it, otherwise the mapping grows forever.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call a command handler with the given arguments, awaiting it when it is
        a coroutine function (possibly wrapped in :func:`functools.partial`).

        :raises XMPPError: the handler's own XMPPError unchanged, or an
            ``internal-server-error`` wrapping any other exception.
        """
        try:
            if inspect.iscoroutinefunction(f) or inspect.iscoroutinefunction(
                getattr(f, "func", None)
            ):
                return await f(*a, **k)
            return f(*a, **k)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Process the user's answer to ``result``: cancel its timeout, extract
        the values, call the form's handler and translate what it returns.
        """
        self.__cancel_timeout(adhoc_session["id"])
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Process the answer to a confirmation form: run its handler if the user
        confirmed, otherwise end the command with a cancellation note.

        When ``confirmation.success`` is set, it replaces the handler's result.
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        :param command: the command to expose. If it has a ``CATEGORY``, it is
            listed under that category's node instead of getting its own node.
        :param jid: the JID the command is attached to; defaults to the
            gateway's bound JID.
        :raises RuntimeError: if an uncategorized command is already registered
            for the same node.
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                # First command of this category: expose the category node.
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command,
        ifrom: JID,
        session: Optional["BaseSession[Any, Any]"],
    ) -> bool:
        """Whether ``ifrom`` may run ``command``, using an already-fetched session."""
        try:
            command.raise_if_not_authorized(ifrom, fetch_session=False, session=session)
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester neither matches the
            user JID validator nor is an admin.
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        # Fetched once and reused for every authorization check below.
        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands is allowed.
                authorized = any(
                    self.__is_authorized(command, ifrom, session)
                    for command in self._categories[item_node]
                )
            elif (command := self._commands.get(item_node)) is not None:
                authorized = self.__is_authorized(command, ifrom, session)
            else:
                # Not registered through this provider: we cannot check who may
                # run it, so do not advertise it (instead of crashing the query).
                log.warning("Unknown adhoc command node: %s", item_node)
                authorized = False

            if authorized:
                filtered_items.append(item)

        return filtered_items

314 

315 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Pass *text* through :func:`strip_leading_emoji` when the
    ``STRIP_LEADING_EMOJI_ADHOC`` option is enabled; otherwise return it as-is.

    Used for the labels of adhoc commands and command categories.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text

320 

321 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)