Coverage for slidge / command / adhoc.py: 85%

202 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-20 19:56 +0000

1import asyncio 

2import inspect 

3import logging 

4from collections.abc import Awaitable, Callable 

5from functools import partial 

6from typing import TYPE_CHECKING, Any, Generic, Optional, ParamSpec, TypeVar, cast 

7 

8from slixmpp import JID, Iq 

9from slixmpp.exceptions import XMPPError 

10from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

11from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

12from slixmpp.plugins.xep_0050.adhoc import CommandType 

13 

14from slidge.util.types import AnyContact, AnyMUC, AnyRecipient, AnySession 

15 

16from ..core import config 

17from ..util.util import strip_leading_emoji 

18from . import Command, CommandResponseType, Confirmation, Form, TableResult 

19from .base import ( 

20 CommandResponseRecipientType, 

21 CommandResponseSessionType, 

22 ContactCommand, 

23 FormField, 

24 MUCCommand, 

25) 

26from .categories import CommandCategory 

27 

28if TYPE_CHECKING: 

29 from ..core.gateway import BaseGateway 

30 from ..core.session import BaseSession 

31 

32 

33GatewayType = TypeVar("GatewayType", bound="BaseGateway[Any]") 

34AdhocSessionType = dict[str, Any] 

35T = TypeVar("T") 

36P = ParamSpec("P") 

37 

38 

class AdhocProvider(Generic[GatewayType]):
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Gateway-wide commands are registered with slixmpp's ``xep_0050`` plugin on
    the gateway JID (see :meth:`register`). Commands attached to a contact or
    a group are looked up on demand by the ``get_command`` API hook installed
    in ``__init__``.
    """

    # Seconds before an unanswered form that defines a ``timeout_handler``
    # expires: its adhoc session is dropped and the handler is called.
    FORM_TIMEOUT = 120  # seconds
    xmpp: GatewayType

    def __init__(self, xmpp: GatewayType) -> None:
        self.xmpp = xmpp
        # Commands registered without a category, keyed by node.
        self._commands = dict[str, Command[AnySession]]()
        # Commands registered with a category, keyed by category node. Each
        # category is exposed as a single adhoc command listing its members.
        self._categories = dict[str, list[Command[AnySession]]]()
        # Pending form timeouts, keyed by adhoc session id. Created before the
        # handlers below are registered so they can always rely on it.
        self.__timeouts: dict[str, asyncio.TimerHandle] = {}
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )
        self.xmpp.plugin["xep_0050"].api.register(self.__get_command, "get_command")

    async def __get_command(
        self,
        jid: JID | str | None = None,
        node: str | None = None,
        ifrom: JID | None = None,
        args: None = None,
    ) -> CommandType | None:
        """
        ``get_command`` API hook for slixmpp's ``xep_0050`` plugin.

        :param jid: the entity the command is executed on; defaults to the
            gateway's bare JID
        :param node: the command node
        :param ifrom: the JID requesting the command
        :param args: unused
        :return: for the gateway JID, whatever ``xep_0050`` has registered for
            ``(jid, node)``; otherwise a ``(name, handler, None, None)`` tuple
            for the command of the matching contact or group (a participant is
            mapped to its contact), or ``None`` when there is no such command.
        :raises XMPPError: ``undefined-condition`` if *ifrom* is missing for a
            non-gateway JID, ``subscription-required`` if *ifrom* has no session.
        """
        if node is None:
            return None
        if jid is None:
            jid = self.xmpp.boundjid.bare
        if not isinstance(jid, JID):
            jid = JID(jid)
        if jid == self.xmpp.boundjid.bare:
            return self.xmpp.plugin["xep_0050"].commands.get((jid.full, node))
        if ifrom is None:
            raise XMPPError("undefined-condition")
        session = self.xmpp.get_session_from_jid(ifrom)
        if session is None:
            raise XMPPError("subscription-required")
        recipient = await session.get_contact_or_group_or_participant(jid)
        if recipient is None:
            return None
        if recipient.is_participant:
            # Commands are looked up on the participant's contact, if any.
            if recipient.contact is None:
                return None
            recipient = recipient.contact
        command = recipient.commands.get(node)
        if command is None:
            return None
        name = strip_leading_emoji_if_needed(command.NAME)
        handler = partial(self.__wrap_initial_handler, command, recipient=recipient)
        return name, handler, None, None  # type:ignore

    async def __wrap_initial_handler(
        self,
        command: Command[AnySession]
        | type[ContactCommand[AnyContact]]
        | type[MUCCommand[AnyMUC]],
        iq: Iq,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Run the first step of a command and turn its result into adhoc state.

        Without *recipient*, *command* is a gateway command: the requester's
        authorization is checked and ``run(session, ifrom)`` is called.
        With *recipient*, ``run(recipient)`` is called instead.
        """
        ifrom = iq.get_from()
        if recipient is None:
            cmd = cast(Command[AnySession], command)
            session = cmd.raise_if_not_authorized(ifrom)
            result1: CommandResponseSessionType[Any] = await self.__wrap_handler(
                cmd.run, session, ifrom
            )
            return await self.__handle_result(session, result1, adhoc_session)
        else:
            cmd2 = cast(
                type[ContactCommand[AnyContact]] | type[MUCCommand[AnyMUC]], command
            )
            result2: CommandResponseRecipientType[Any] = await self.__wrap_handler(
                cmd2.run,  # type:ignore[arg-type]
                recipient,
            )
            return await self.__handle_result(
                recipient.session, result2, adhoc_session, recipient
            )

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Present a form that lets the requester pick a command of *category*.

        Only commands the requester is authorized to run are offered.

        :raises XMPPError: ``not-authorized`` if none of the category's
            commands can be run by the requester.
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # Some commands (e.g. for non-registered users) run without a session.
            session = None
        commands: dict[str, Command[AnySession]] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(iq.get_from())
            except XMPPError:
                continue
            commands[command.NODE] = command
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )
        return await self.__handle_result(
            session,
            Form(
                category.name,
                "",
                [
                    FormField(
                        var="command",
                        label="Command",
                        type="list-single",
                        options=[
                            {
                                "label": strip_leading_emoji_if_needed(command.NAME),
                                "value": command.NODE,
                            }
                            for command in commands.values()
                        ],
                    )
                ],
                partial(self.__handle_category_choice, commands),
            ),
            adhoc_session,
        )

    async def __handle_category_choice(
        self,
        commands: dict[str, Command[AnySession]],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ) -> CommandResponseSessionType[Any]:
        """Run the command the user picked in the category form."""
        command = commands[form_values["command"]]
        result: CommandResponseSessionType[Any] = await self.__wrap_handler(
            command.run, session, jid
        )
        return result

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
        recipient: AnyRecipient | None = None,
    ) -> AdhocSessionType:
        """
        Translate a command step's *result* into the xep_0050 session dict.

        - ``str``/``None``: the command ends with an info note
          (``"Success!"`` when the result is empty).
        - :class:`Form`: the form is sent and its submission is routed to
          :meth:`__wrap_form_handler`; a timeout is armed if the form has a
          ``timeout_handler``.
        - :class:`Confirmation`: a confirmation form is sent and routed to
          :meth:`__wrap_confirmation`.
        - :class:`TableResult`: the table is sent and the command ends.

        :raises XMPPError: ``internal-server-error`` for any other result type.
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(
                self.__wrap_form_handler, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            if result.timeout_handler is not None:
                self.__timeouts[adhoc_session["id"]] = self.xmpp.loop.call_later(
                    self.FORM_TIMEOUT,
                    partial(
                        self.__wrap_timeout, result.timeout_handler, adhoc_session["id"]
                    ),
                )
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(
                self.__wrap_confirmation, session, result, recipient
            )
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    def __cancel_timeout(self, session_id: str) -> None:
        """Cancel and forget the pending form timeout of *session_id*, if any."""
        timer = self.__timeouts.pop(session_id, None)
        if timer is not None:
            log.debug("Cancelled form timeout for adhoc session %s", session_id)
            timer.cancel()

    def __wrap_timeout(self, handler: Callable[[], None], session_id: str) -> None:
        """
        Expire adhoc session *session_id* and call the form's timeout handler.

        Scheduled with ``loop.call_later`` by :meth:`__handle_result`.
        """
        # The timer has fired: forget its handle so the mapping does not keep
        # growing with entries that can no longer be cancelled.
        self.__timeouts.pop(session_id, None)
        try:
            del self.xmpp.plugin["xep_0050"].sessions[session_id]
        except KeyError:
            log.error("Timeout but session could not be found: %s", session_id)
        handler()

    @staticmethod
    async def __wrap_handler(
        f: Callable[P, Awaitable[T] | T],
        *a: P.args,
        **k: P.kwargs,
    ) -> T:
        """
        Call *f* (sync or async) and normalise its errors.

        Whatever *f* returns is awaited if it is awaitable, which covers
        coroutine functions, ``functools.partial`` around them and any other
        callable returning an awaitable.

        :raises XMPPError: re-raised unchanged when *f* raises one; any other
            exception is logged and converted to ``internal-server-error``.
        """
        try:
            result = f(*a, **k)
            if inspect.isawaitable(result):
                result = await result
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e
        return result  # type:ignore[return-value]

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the submission of *result* (a :class:`Form`).

        Any pending timeout for this adhoc session is cancelled. The form
        handler is then called as ``handler(values, session, from, ...)`` for
        gateway commands, or ``handler(recipient, values, ...)`` for
        contact/group commands.
        """
        self.__cancel_timeout(adhoc_session["id"])
        form_values = result.get_values(form)
        if recipient is None:
            new_result = await self.__wrap_handler(
                result.handler,
                form_values,
                session,
                adhoc_session["from"],
                *result.handler_args,
                **result.handler_kwargs,
            )
        else:
            new_result = await self.__wrap_handler(
                result.handler,
                recipient,
                form_values,
                *result.handler_args,
                **result.handler_kwargs,
            )
        return await self.__handle_result(session, new_result, adhoc_session, recipient)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        recipient: AnyRecipient | None,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a :class:`Confirmation`.

        If the user did not confirm, the command ends with a cancellation
        note. Otherwise the confirmation handler runs. When
        ``confirmation.success`` is set, it replaces the handler's result for
        both gateway and contact/group commands.
        """
        if not form.get_values().get("confirm"):
            return await self.__handle_result(
                session, "You canceled the operation", adhoc_session, recipient
            )
        if recipient is None:
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
        else:
            result = await self.__wrap_handler(
                confirmation.handler,
                recipient,
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
        if confirmation.success:
            result = confirmation.success
        return await self.__handle_result(session, result, adhoc_session, recipient)

    def register(self, command: Command[AnySession], jid: JID | None = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        Commands without a category are registered on their own node. Commands
        with a category are grouped: the first command of a category registers
        the category node, whose handler lets the user pick a command.

        :param command: the command to register
        :param jid: the JID to register it on, defaults to the gateway's JID
        :raises RuntimeError: if an uncategorized command already uses the
            same node
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command[AnySession]]()
                self.xmpp.plugin["xep_0050"].add_command(
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    def __is_node_authorized(
        self,
        node: str,
        ifrom: JID,
        session: Optional["BaseSession[Any, Any]"],
    ) -> bool:
        """
        Tell whether *ifrom* may see the command or category at *node*.

        A category is visible as soon as one of its commands is authorized.
        Nodes not registered through :meth:`register` are left visible: there
        is no command object here to check them against.
        """
        if node in self._categories:
            candidates = self._categories[node]
        else:
            command = self._commands.get(node)
            if command is None:
                return True
            candidates = [command]
        for candidate in candidates:
            try:
                candidate.raise_if_not_authorized(
                    ifrom, fetch_session=False, session=session
                )
            except XMPPError:
                continue
            return True
        # An empty category (register() never creates one) stays visible.
        return not candidates

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` when the requester neither matches
            the gateway's JID validator nor is listed in ``config.ADMINS``
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            if self.__is_node_authorized(item["node"], ifrom, session):
                filtered_items.append(item)

        return filtered_items

411 

412 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return *text* with its leading emoji removed when the
    ``STRIP_LEADING_EMOJI_ADHOC`` option is enabled, unchanged otherwise.

    Applied to the command and category names this module exposes.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text

417 

418 

# Module logger, named after this module.
log = logging.getLogger(name=__name__)