Coverage for slidge/command/adhoc.py: 88%

158 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

1import asyncio 

2import functools 

3import inspect 

4import logging 

5from functools import partial 

6from typing import TYPE_CHECKING, Any, Callable, Optional, Union 

7 

8from slixmpp import JID, Iq 

9from slixmpp.exceptions import XMPPError 

10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

12 

13from ..core import config 

14from ..util.util import strip_leading_emoji 

15from . import Command, CommandResponseType, Confirmation, Form, TableResult 

16from .base import FormField 

17from .categories import CommandCategory 

18 

19if TYPE_CHECKING: 

20 from ..core.gateway import BaseGateway 

21 from ..core.session import BaseSession 

22 

23 

# Per-session dict used by the xep_0050 adhoc flow in this module; the code
# below reads/writes the keys "id", "from", "next", "has_next", "payload"
# and "notes".
AdhocSessionType = dict[str, Any]

25 

26 

class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a category are registered as their own XEP-0050 node.
    Commands with a category are grouped behind one node per category, which
    first answers with a form letting the user pick the actual command.
    """

    FORM_TIMEOUT = 120  # seconds before a Form's timeout_handler is invoked

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # node -> command, for commands registered without a category
        self._commands = dict[str, Command]()
        # category node -> commands belonging to that category
        self._categories = dict[str, list[Command]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        # adhoc session id -> pending timer that will call a Form's timeout_handler
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of an uncategorized command.

        Checks that the sender may run ``command``, runs it and stores its
        result in ``adhoc_session``.

        :raises XMPPError: if the sender is not authorized or the command fails
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a category node: answer with a list-single form of the
        commands of ``category`` that the sender is authorized to run.

        :raises XMPPError: ``not-authorized`` if no command of the category is
            accessible to the sender
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # Callers without a session may still be allowed to run some
            # commands, so this is not fatal here.
            session = None
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ):
        """
        Form handler of the category list: run the command picked in the
        ``command`` field and return its result.
        """
        command = commands[form_values["command"]]
        result = await self.__wrap_handler(command.run, session, jid)
        return result

    def __cancel_timeout(self, session_id: str) -> None:
        """Cancel and forget the pending form timeout of an adhoc session, if any."""
        timer = self.__timeouts.pop(session_id, None)
        if timer is not None:
            log.debug("Cancelled form timeout for adhoc session %s", session_id)
            timer.cancel()

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command step's return value into the adhoc session dict.

        - ``str`` / ``None``: the command completes with an info note
        - :class:`Form`: the next step is the form submission; if the form has
          a ``timeout_handler``, it is scheduled after :attr:`FORM_TIMEOUT`
        - :class:`Confirmation`: the next step is the confirmation form
        - :class:`TableResult`: the command completes with the table payload

        :raises XMPPError: ``internal-server-error`` for any other result type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                session_id = adhoc_session["id"]
                # Never leave an older timer for the same session running:
                # it would fire later and tear down the new step.
                self.__cancel_timeout(session_id)
                self.__timeouts[session_id] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(self.__wrap_timeout, result.timeout_handler, session_id),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Timer callback: drop the expired adhoc session, then call ``handler``.
        """
        # The timer has fired; forget it, otherwise nothing would ever remove
        # it since the session (and its form handler) is deleted just below.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call a sync or async command handler and return its result.

        XMPPErrors raised by the handler propagate unchanged; any other
        exception is logged and converted to ``internal-server-error``.
        """
        try:
            if inspect.iscoroutinefunction(f):
                return await f(*a, **k)
            elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
                return await f(*a, **k)
            else:
                return f(*a, **k)
        except XMPPError:
            # Handlers may raise XMPPError on purpose (e.g. to reject invalid
            # input); keep their condition and text instead of masking them.
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the submission of ``result``: cancel its timeout, run its
        handler with the submitted values and process the handler's result.
        """
        self.__cancel_timeout(adhoc_session["id"])
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a :class:`Confirmation`: run its handler if the
        ``confirm`` field is truthy, otherwise report the cancellation.
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        :param command: the command to expose. If it has a ``CATEGORY``, it is
            grouped under that category's node instead of getting its own node.
        :param jid: the JID the command is attached to; defaults to the
            gateway's bound JID
        :raises RuntimeError: if an uncategorized command is already registered
            for the same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                # First command of this category: expose the category node.
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command, ifrom: JID, session: Optional["BaseSession[Any, Any]"]
    ) -> bool:
        """Return whether ``ifrom`` (with its already-fetched session) may run ``command``."""
        try:
            command.raise_if_not_authorized(ifrom, fetch_session=False, session=session)
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester neither matches the
            gateway's JID validator nor is listed in ``config.ADMINS``
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands is allowed.
                authorized = any(
                    self.__is_authorized(command, ifrom, session)
                    for command in self._categories[item_node]
                )
            elif (command := self._commands.get(item_node)) is not None:
                authorized = self.__is_authorized(command, ifrom, session)
            else:
                # Node added to xep_0050 without going through register(): there
                # is no Command to check, and a missing entry must not break the
                # whole disco#items reply, so keep listing it.
                authorized = True

            if authorized:
                filtered_items.append(item)

        return filtered_items

312 

313 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return ``text`` as it should appear in adhoc command names and labels.

    If ``config.STRIP_LEADING_EMOJI_ADHOC`` is enabled, any leading emoji is
    removed with :func:`strip_leading_emoji`; otherwise ``text`` is returned
    untouched.
    """
    should_strip = config.STRIP_LEADING_EMOJI_ADHOC
    return strip_leading_emoji(text) if should_strip else text

318 

319 

# Module logger used by the AdhocProvider methods of this module.
log = logging.getLogger(__name__)