Coverage for slidge / command / adhoc.py: 86%

196 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import asyncio 

2import inspect 

3import logging 

4from collections.abc import Awaitable, Callable 

5from functools import partial 

6from typing import TYPE_CHECKING, Any, Optional, ParamSpec, TypeVar, cast 

7 

8from slixmpp import JID, Iq 

9from slixmpp.exceptions import XMPPError 

10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

12from slixmpp.plugins.xep_0050.adhoc import CommandType 

13 

14from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession 

15 

16from ..core import config 

17from ..util.util import strip_leading_emoji 

18from . import Command, CommandResponseType, Confirmation, Form, TableResult 

19from .base import ( 

20 CommandResponseRecipientType, 

21 CommandResponseSessionType, 

22 ContactCommand, 

23 FormField, 

24 MUCCommand, 

25) 

26from .categories import CommandCategory 

27 

28if TYPE_CHECKING: 

29 from ..core.gateway import BaseGateway 

30 from ..core.session import BaseSession 

31 

32 

33AdhocSessionType = dict[str, Any] 

34T = TypeVar("T") 

35P = ParamSpec("P") 

36 

37 

class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a category are registered as individual XEP-0050 nodes.
    Commands sharing a category are exposed behind a single node per
    category, which answers with a form listing the commands the requester
    may run. Commands attached to contacts and groups are resolved on demand
    through the xep_0050 ``get_command`` API hook registered in ``__init__``.
    """

    # Delay, in seconds, before a pending form whose ``timeout_handler`` is
    # set gets abandoned and that handler is called (see ``__wrap_timeout``).
    FORM_TIMEOUT = 120  # seconds

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Uncategorized commands, keyed by command node.
        self._commands = dict[str, Command[AnySession]]()
        # Categorized commands, keyed by category node.
        self._categories = dict[str, list[Command[AnySession]]]()
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        self.xmpp.plugin["xep_0050"].api.register(self.__get_command, "get_command")
        # Pending form-timeout timers, keyed by adhoc session ID.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}

    async def __get_command(
        self,
        jid: JID | str | None = None,
        node: str | None = None,
        ifrom: JID | None = None,
        args: None = None,
    ) -> CommandType | None:
        """
        Resolve a command for slixmpp's xep_0050 ``get_command`` API.

        Commands addressed to the gateway's bare JID are looked up in the
        xep_0050 plugin's own registry. Commands addressed to any other JID
        are looked up on the matching contact or group of the requester's
        session. Participants never expose commands.

        :param jid: the JID the command is addressed to; defaults to the
            gateway's bare JID
        :param node: the command node
        :param ifrom: the JID of the requesting entity
        :param args: unused
        :return: the xep_0050 command entry for gateway commands, a
            ``(name, handler, None, None)`` tuple for recipient commands, or
            ``None`` if there is no such command
        :raises XMPPError: ``undefined-condition`` when ``ifrom`` is missing
            for a non-gateway JID, ``subscription-required`` when the
            requester has no session
        """
        if node is None:
            return None
        if jid is None:
            jid = self.xmpp.boundjid.bare
        if not isinstance(jid, JID):
            jid = JID(jid)
        if jid == self.xmpp.boundjid.bare:
            return self.xmpp.plugin["xep_0050"].commands.get((jid.full, node))
        if ifrom is None:
            raise XMPPError("undefined-condition")
        session = self.xmpp.get_session_from_jid(ifrom)
        if session is None:
            raise XMPPError("subscription-required")
        recipient = await session.get_contact_or_group_or_participant(jid)
        if recipient is None or recipient.is_participant:
            return None
        command = recipient.commands.get(node)
        if command is None:
            return None
        name = strip_leading_emoji_if_needed(command.NAME)
        handler = partial(self.__wrap_initial_handler, command, recipient=recipient)
        return name, handler, None, None  # type:ignore

    async def __wrap_initial_handler(
        self,
        command: Command[AnySession]
        | type[ContactCommand[AnyContact]]
        | type[MUCCommand[AnyMUC]],
        iq: Iq,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Run the first step of a command and translate its result into the
        xep_0050 session.

        :param command: a gateway command (when ``recipient`` is ``None``) or
            a contact/MUC command class (otherwise)
        :param iq: the command execution IQ
        :param adhoc_session: the xep_0050 session dict, updated in place
        :param recipient: the contact or group the command is addressed to,
            if any
        :return: the updated ``adhoc_session``
        """
        ifrom = iq.get_from()
        if recipient is None:
            cmd = cast(Command[AnySession], command)
            # Raises XMPPError if the requester may not run this command.
            session = cmd.raise_if_not_authorized(ifrom)
            result1: CommandResponseSessionType[Any] = await self.__wrap_handler(
                cmd.run, session, ifrom
            )
            return await self.__handle_result(session, result1, adhoc_session)
        else:
            cmd2 = cast(
                type[ContactCommand[AnyContact]] | type[MUCCommand[AnyMUC]], command
            )
            result2: CommandResponseRecipientType[Any] = await self.__wrap_handler(
                cmd2.run,  # type:ignore[arg-type]
                recipient,
            )
            return await self.__handle_result(
                recipient.session, result2, adhoc_session, recipient
            )

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Answer an execution request on a category node with a form listing
        the commands of that category the requester is authorized to run.

        :raises XMPPError: ``not-authorized`` if no command of the category
            can be run by the requester
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # Commands may be runnable without a session, so this is not fatal.
            session = None
        commands: dict[str, Command[AnySession]] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if len(commands) == 0:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command[AnySession]],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseSessionType[Any]:
        """
        Form handler for a category list: run the command picked in the
        ``command`` field.

        :param commands: the authorized commands offered in the form, keyed
            by node
        :param form_values: the submitted form values
        :param session: the requester's session
        :param jid: the requester's JID
        :return: the result of the chosen command's ``run``
        """
        command = commands[form_values["command"]]
        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            command.run, session, jid
        )
        return result

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Translate a command step result into the xep_0050 session dict.

        - ``str``/``None``: final step, shown as an info note ("Success!" when
          empty).
        - :class:`Form`: another step; schedules a timeout if the form has a
          ``timeout_handler``.
        - :class:`Confirmation`: another step asking for confirmation.
        - :class:`TableResult`: final step with a table payload.

        :raises XMPPError: ``internal-server-error`` for any other result type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(
                self.__wrap_form_handler, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(
                self.__wrap_confirmation, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        # Make the offending value visible in the logs; the client only sees
        # a generic error.
        log.error("Unexpected adhoc command result: %r", result)
        raise XMPPError("internal-server-error", text="OOPS!")

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Called by the event loop when a form was not submitted in time: drop
        the xep_0050 session and call the form's timeout handler.

        :param handler: the form's ``timeout_handler``
        :param session_id: the adhoc session ID the form belongs to
        """
        # The timer has fired; forget its handle so timed-out sessions do not
        # accumulate in ``__timeouts``.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T:
        """
        Call ``f`` with the given arguments, awaiting it if it is a coroutine
        function (or a ``partial`` of one).

        XMPPError is re-raised as is; any other exception is logged and
        converted into an ``internal-server-error`` XMPPError.
        """
        try:
            if inspect.iscoroutinefunction(f):
                return await f(*a, **k)  # type:ignore[no-any-return]
            elif hasattr(f, "func") and inspect.iscoroutinefunction(f.func):
                return await f(*a, **k)  # type:ignore[misc,no-any-return]
            else:
                return f(*a, **k)  # type:ignore[return-value]
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e))

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle a submitted :class:`Form`: cancel its pending timeout, parse
        the submitted values and pass them to the form's handler.

        Gateway commands call ``handler(values, session, from, *args,
        **kwargs)``; recipient commands call ``handler(recipient, values,
        *args, **kwargs)``.
        """
        timer = self.__timeouts.pop(adhoc_session["id"], None)
        if timer is not None:
            log.debug("Cancelled form timeout for adhoc session %s", adhoc_session["id"])
            timer.cancel()
        form_values = result.get_values(form)
        if recipient is None:
            new_result = await self.__wrap_handler(
                result.handler,
                form_values,
                session,
                adhoc_session["from"],
                *result.handler_args,
                **result.handler_kwargs,
            )
        else:
            new_result = await self.__wrap_handler(
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
        return await self.__handle_result(session, new_result, adhoc_session, recipient)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a :class:`Confirmation`.

        When confirmed, the confirmation's handler is run; its result is
        replaced by ``confirmation.success`` when that is set. Otherwise the
        operation is reported as canceled.
        """
        if form.get_values().get("confirm"):
            if recipient is None:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    session,
                    adhoc_session["from"],
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            else:
                result = await self.__wrap_handler(
                    confirmation.handler,
                    recipient,
                    *confirmation.handler_args,
                    **confirmation.handler_kwargs,
                )
            # The static success message overrides the handler's result for
            # gateway and recipient commands alike.
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session, recipient)

    def register(self, command: Command[AnySession], jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care
        of that.

        :param command: the command to register. If it has a ``CATEGORY``, it
            is listed behind that category's node instead of getting a node
            of its own.
        :param jid: the JID to register the command on; defaults to the
            gateway's JID.
        :raises RuntimeError: if an uncategorized command is already
            registered for the same node.
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                # RuntimeError does not %-format its arguments: build the
                # message explicitly.
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                # First command of this category: expose the category node.
                self._categories[node] = list[Command[AnySession]]()
                self.xmpp.plugin["xep_0050"].add_command(
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        A category node is listed if at least one of its commands is
        authorized for the requester.

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester neither matches the
            user JID validator nor is an admin
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            authorized = True
            if item["node"] in self._categories:
                for command in self._categories[item["node"]]:
                    try:
                        command.raise_if_not_authorized(
                            ifrom, fetch_session=False, session=session
                        )
                    except XMPPError:
                        authorized = False
                    else:
                        authorized = True
                        break
            else:
                try:
                    self._commands[item["node"]].raise_if_not_authorized(
                        ifrom, fetch_session=False, session=session
                    )
                except XMPPError:
                    authorized = False

            if authorized:
                filtered_items.append(item)

        return filtered_items

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return ``text`` with its leading emoji removed when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is enabled, or unchanged
    otherwise.

    Applied to command and category names advertised to XMPP clients.
    """
    if not config.STRIP_LEADING_EMOJI_ADHOC:
        return text
    return strip_leading_emoji(text)

# Module logger. It is defined at the bottom of the module, but the code above
# only looks it up at call time, so that is fine.
log: logging.Logger = logging.getLogger(__name__)