Coverage for slidge/command/adhoc.py: 92%

143 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1import asyncio 

2import functools 

3import logging 

4from functools import partial 

5from typing import TYPE_CHECKING, Any, Callable, Optional, Union 

6 

7from slixmpp import JID, Iq 

8from slixmpp.exceptions import XMPPError 

9from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

10from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

11 

12from ..core import config 

13from ..util.util import strip_leading_emoji 

14from . import Command, CommandResponseType, Confirmation, Form, TableResult 

15from .base import FormField 

16from .categories import CommandCategory 

17 

18if TYPE_CHECKING: 

19 from ..core.gateway import BaseGateway 

20 from ..core.session import BaseSession 

21 

22 

# Session dict exchanged with slixmpp's XEP-0050 (adhoc commands) plugin.
# This module reads "from" and writes "next", "has_next", "payload" and
# "notes" on it; values are untyped, hence ``Any``.
AdhocSessionType = dict[str, Any]

24 

25 

class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    untyped dict values than slixmpp.

    Commands without a category are registered as their own XEP-0050 node.
    Commands with a category are grouped: the category is registered as a
    single node, and executing it presents a form listing the commands of that
    category the requester is authorized to run.
    """

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # node -> command, for commands registered without a category
        self._commands = dict[str, Command]()
        # category node -> commands registered under that category
        self._categories = dict[str, list[Command]]()
        # Serve disco#items for the commands node ourselves, so that only the
        # commands usable by the requester are listed (see get_items()).
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of an uncategorized command: check authorization, run the
        command and translate its result into the adhoc session.

        :param command: the command, bound with ``partial`` in :meth:`register`
        :param iq: the IQ that started the command execution
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session dict
        :raises XMPPError: from ``raise_if_not_authorized``, or if the command
            fails
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        First step of a category node: reply with a form that lets the user
        pick one of the commands of *category* they are authorized to run.

        :param category: the category, bound with ``partial`` in :meth:`register`
        :param iq: the IQ that started the command execution
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session dict
        :raises XMPPError: ``not-authorized`` if no command of the category is
            available to the sender
        """
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # The sender may have no session; commands are filtered one by one
            # below, so this is not an error by itself.
            session = None

        ifrom = iq.get_from()
        commands: dict[str, Command] = {}
        for command in self._categories[category.node]:
            try:
                command.raise_if_not_authorized(ifrom)
            except XMPPError:
                continue
            commands[command.NODE] = command

        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )

        choice_form = Form(
            category.name,
            "",
            [
                FormField(
                    var="command",
                    label="Command",
                    type="list-single",
                    options=[
                        {
                            "label": strip_leading_emoji_if_needed(command.NAME),
                            "value": command.NODE,
                        }
                        for command in commands.values()
                    ],
                )
            ],
            partial(self.__handle_category_choice, commands),
        )
        return await self.__handle_result(session, choice_form, adhoc_session)

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: Optional["BaseSession[Any, Any]"],
        jid: JID,
    ) -> CommandResponseType:
        """
        Form handler of the category form: run the command the user picked.

        :param commands: the commands offered in the form, keyed by node
        :param form_values: the submitted form values
        :param session: the user's session, if any
        :param jid: the JID of the user
        :return: the result of the selected command
        :raises XMPPError: ``bad-request`` if the selection is not one of the
            offered commands
        """
        try:
            command = commands[form_values["command"]]
        except KeyError:
            raise XMPPError("bad-request", "Please select a valid command") from None
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command result into the slixmpp adhoc session dict.

        - ``str`` or ``None``: final step, shown as an info note
        - :class:`Form`: payload is the form, next step is its handler
        - :class:`Confirmation`: payload is a confirmation form, next step runs
          the confirmation handler
        - :class:`TableResult`: final step, payload is the table

        :raises XMPPError: ``internal-server-error`` for any other result type
        """
        if isinstance(result, str) or result is None:
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call *f* with the given arguments, awaiting it if it is a coroutine
        function or a ``partial`` wrapping one.

        An :class:`XMPPError` raised by *f* is propagated unchanged so the
        condition and text chosen by the handler reach the user. Any other
        exception is logged and converted into an ``internal-server-error``.
        """
        try:
            if asyncio.iscoroutinefunction(f) or asyncio.iscoroutinefunction(
                getattr(f, "func", None)
            ):
                return await f(*a, **k)
            return f(*a, **k)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Form` result: extract the submitted values and
        pass them to the form's handler as
        ``handler(form_values, session, jid, *handler_args, **handler_kwargs)``.

        :param session: the user's session, if any
        :param result: the form that was sent to the user
        :param form: the form submitted by the user
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session dict
        """
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )
        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Next step after a :class:`Confirmation` result.

        If the user confirmed, run the confirmation handler. Report
        ``confirmation.success`` when it is set, otherwise the handler's result.
        If the user did not confirm, report a cancellation message.

        :param session: the user's session, if any
        :param confirmation: the confirmation that was sent to the user
        :param form: the form submitted by the user
        :param adhoc_session: the slixmpp adhoc session dict
        :return: the updated adhoc session dict
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually, ``BaseGateway`` takes care of
        that.

        A command without a ``CATEGORY`` is exposed as its own node. A command
        with a category is appended to that category's list; the category node
        itself is created the first time one of its commands is registered.

        :param command: the command to register
        :param jid: the JID the command is exposed on; defaults to the
            gateway's bound JID
        :raises RuntimeError: if an uncategorized command with the same
            ``NODE`` is already registered
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
        else:
            if isinstance(category, str):
                category = CommandCategory(category, category)
            node = category.node
            name = category.name
            if node not in self._categories:
                self._categories[node] = list[Command]()
                self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                    jid=jid,
                    node=node,
                    name=strip_leading_emoji_if_needed(name),
                    handler=partial(self.__handle_category_list, category),
                )
            self._categories[node].append(command)

    @staticmethod
    def __is_authorized(
        command: Command, ifrom: JID, session: Optional["BaseSession[Any, Any]"]
    ) -> bool:
        """Return whether *ifrom* may run *command*, using an already-fetched session."""
        try:
            command.raise_if_not_authorized(ifrom, fetch_session=False, session=session)
        except XMPPError:
            return False
        return True

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query

        :param jid: the entity that should return its items
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: commands accessible to the given JID will be listed
        :raises XMPPError: ``forbidden`` if the requester matches neither the
            user JID validator nor the configured admins
        """
        ifrom = iq.get_from()
        ifrom_str = str(ifrom)
        if (
            not self.xmpp.jid_validator.match(ifrom_str)
            and ifrom_str not in config.ADMINS
        ):
            raise XMPPError(
                "forbidden",
                "You are not authorized to execute adhoc commands on this gateway. "
                "If this is unexpected, ask your administrator to verify that "
                "'user-jid-validator' is correctly set in slidge's configuration.",
            )

        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        # Fetch the session once and reuse it for every authorization check.
        session = self.xmpp.get_session_from_jid(ifrom)

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            item_node = item["node"]
            if item_node in self._categories:
                # A category is listed if at least one of its commands is usable.
                authorized = any(
                    self.__is_authorized(command, ifrom, session)
                    for command in self._categories[item_node]
                )
            else:
                authorized = self.__is_authorized(
                    self._commands[item_node], ifrom, session
                )
            if authorized:
                filtered_items.append(item)

        return filtered_items

290 

291 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return *text* without its leading emoji when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is enabled, unchanged otherwise.

    Used for the command and category names advertised as adhoc nodes.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text

296 

297 

# Module logger. It is defined after the class, but the methods that use it
# resolve the name only when they are called.
log = logging.getLogger(__name__)