Coverage for slidge/command/adhoc.py: 92%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import asyncio 

2import functools 

3import logging 

4from functools import partial 

5from typing import TYPE_CHECKING, Any, Callable, Optional, Union 

6 

7from slixmpp import JID, Iq # type: ignore[attr-defined] 

8from slixmpp.exceptions import XMPPError 

9from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

10from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

11 

12from ..core import config 

13from ..util.util import strip_leading_emoji 

14from . import Command, CommandResponseType, Confirmation, Form, TableResult 

15from .base import FormField 

16from .categories import CommandCategory 

17 

18if TYPE_CHECKING: 

19 from ..core.gateway import BaseGateway 

20 from ..core.session import BaseSession 

21 

22 

# The dict that slixmpp's xep_0050 plugin hands to the handlers registered with
# ``add_command`` and to each ``next`` callback. This module reads "from" and
# writes "next", "has_next", "payload" and "notes".
AdhocSessionType = dict[str, Any]


class AdhocProvider:
    """
    A slixmpp-like plugin to handle adhoc commands, with less boilerplate and
    fewer untyped dict values than using slixmpp's xep_0050 plugin directly.
    """

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp
        # Commands without a category, keyed by their adhoc node.
        self._commands = dict[str, Command]()
        # Categorized commands, keyed by the category's node.
        self._categories = dict[str, list[Command]]()
        # Answer disco#items on the commands node ourselves so that the list
        # can be filtered per requester (see get_items).
        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=xmpp.boundjid,
            node=self.xmpp.plugin["xep_0050"].stanza.Command.namespace,
            handler=self.get_items,
        )

    @staticmethod
    def __can_run(command: Command, ifrom: JID) -> bool:
        """Return whether *ifrom* passes ``command.raise_if_not_authorized``."""
        try:
            command.raise_if_not_authorized(ifrom)
        except XMPPError:
            return False
        return True

    async def __wrap_initial_handler(
        self, command: Command, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point registered with xep_0050 for an uncategorized command.

        Checks that the sender may run *command*, runs it, and turns its result
        into the adhoc session dict expected by slixmpp.
        """
        ifrom = iq.get_from()
        session = command.raise_if_not_authorized(ifrom)
        result = await self.__wrap_handler(command.run, session, ifrom)
        return await self.__handle_result(session, result, adhoc_session)

    async def __handle_category_list(
        self, category: CommandCategory, iq: Iq, adhoc_session: AdhocSessionType
    ) -> AdhocSessionType:
        """
        Entry point registered with xep_0050 for a command category.

        Sends a single-choice form listing the category's commands that the
        sender is authorized to run. The submitted choice is handled by
        ``__handle_category_choice``.

        :raises XMPPError: ``not-authorized`` if the sender may run none of the
            commands in this category.
        """
        ifrom = iq.get_from()
        try:
            session = self.xmpp.get_session_from_stanza(iq)
        except XMPPError:
            # No session for this sender; authorization is still checked per
            # command below.
            session = None

        commands: dict[str, Command] = {
            command.NODE: command
            for command in self._categories[category.node]
            if self.__can_run(command, ifrom)
        }
        if not commands:
            raise XMPPError(
                "not-authorized", "There is no command you can run in this category"
            )

        form = Form(
            category.name,
            "",
            [
                FormField(
                    var="command",
                    label="Command",
                    type="list-single",
                    options=[
                        {
                            "label": strip_leading_emoji_if_needed(command.NAME),
                            "value": command.NODE,
                        }
                        for command in commands.values()
                    ],
                )
            ],
            partial(self.__handle_category_choice, commands),
        )
        return await self.__handle_result(session, form, adhoc_session)

    async def __handle_category_choice(
        self,
        commands: dict[str, Command],
        form_values: dict[str, str],
        session: "BaseSession[Any, Any]",
        jid: JID,
    ):
        """
        Run the command the user picked in the category form.

        :param commands: the commands offered in the form, keyed by node
        :param form_values: the submitted form values
        :raises XMPPError: ``bad-request`` if the submitted choice is not one
            of the offered commands
        """
        choice = form_values.get("command")
        command = commands.get(choice)  # type: ignore[arg-type]
        if command is None:
            # The client submitted a value that was not among the options.
            raise XMPPError("bad-request", f"Unknown command: {choice}")
        return await self.__wrap_handler(command.run, session, jid)

    async def __handle_result(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: CommandResponseType,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Translate a command's return value into slixmpp's adhoc session dict.

        - ``None`` or ``str``: final step, shown as an info note
          (``"Success!"`` when empty).
        - :class:`Form`: the form is sent, and its submission is routed to
          ``__wrap_form_handler``.
        - :class:`Confirmation`: a confirmation form is sent, and the answer is
          routed to ``__wrap_confirmation``.
        - :class:`TableResult`: final step, the table is sent as the payload.

        :raises XMPPError: ``internal-server-error`` for any other result type.
        """
        if result is None or isinstance(result, str):
            adhoc_session["has_next"] = False
            adhoc_session["next"] = None
            adhoc_session["payload"] = None
            adhoc_session["notes"] = [("info", result or "Success!")]
            return adhoc_session

        if isinstance(result, Form):
            adhoc_session["next"] = partial(self.__wrap_form_handler, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        if isinstance(result, Confirmation):
            adhoc_session["next"] = partial(self.__wrap_confirmation, session, result)
            adhoc_session["has_next"] = True
            adhoc_session["payload"] = result.get_form()
            return adhoc_session

        if isinstance(result, TableResult):
            adhoc_session["next"] = None
            adhoc_session["has_next"] = False
            adhoc_session["payload"] = result.get_xml()
            return adhoc_session

        raise XMPPError("internal-server-error", text="OOPS!")

    @staticmethod
    async def __wrap_handler(f: Union[Callable, functools.partial], *a, **k):  # type: ignore
        """
        Call *f* with the given arguments, awaiting the result when *f* is a
        coroutine function (or a ``functools.partial`` wrapping one).

        :raises XMPPError: any ``XMPPError`` raised by *f* is propagated
            unchanged, so commands can report specific error conditions. Any
            other exception is logged and converted to
            ``internal-server-error``.
        """
        try:
            if asyncio.iscoroutinefunction(f) or asyncio.iscoroutinefunction(
                getattr(f, "func", None)
            ):
                return await f(*a, **k)
            return f(*a, **k)
        except XMPPError:
            raise
        except Exception as e:
            log.debug("Exception in %s", f, exc_info=e)
            raise XMPPError("internal-server-error", text=str(e)) from e

    async def __wrap_form_handler(
        self,
        session: Optional["BaseSession[Any, Any]"],
        result: Form,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the submission of a :class:`Form` sent in a previous step.

        The values extracted by ``Form.get_values`` are passed to the form's
        handler, followed by the session, ``adhoc_session["from"]`` and the
        form's extra handler arguments. The handler's return value becomes the
        next step.
        """
        form_values = result.get_values(form)
        new_result = await self.__wrap_handler(
            result.handler,
            form_values,
            session,
            adhoc_session["from"],
            *result.handler_args,
            **result.handler_kwargs,
        )

        return await self.__handle_result(session, new_result, adhoc_session)

    async def __wrap_confirmation(
        self,
        session: Optional["BaseSession[Any, Any]"],
        confirmation: Confirmation,
        form: SlixForm,
        adhoc_session: AdhocSessionType,
    ) -> AdhocSessionType:
        """
        Handle the answer to a :class:`Confirmation` sent in a previous step.

        If the user confirmed, the confirmation's handler is run with the
        session, ``adhoc_session["from"]`` and its extra arguments. Its return
        value is the next step, unless ``confirmation.success`` is set, in
        which case that is shown instead. Otherwise, the operation is reported
        as canceled.
        """
        if form.get_values().get("confirm"):  # type: ignore[no-untyped-call]
            result = await self.__wrap_handler(
                confirmation.handler,
                session,
                adhoc_session["from"],
                *confirmation.handler_args,
                **confirmation.handler_kwargs,
            )
            if confirmation.success:
                result = confirmation.success
        else:
            result = "You canceled the operation"

        return await self.__handle_result(session, result, adhoc_session)

    def register(self, command: Command, jid: Optional[JID] = None) -> None:
        """
        Register a command as an adhoc command.

        This does not need to be called manually; ``BaseGateway`` takes care
        of that.

        A command without a ``CATEGORY`` is exposed directly under its
        ``NODE``. A categorized command is grouped: the first command of a
        category exposes one adhoc entry for the category node, which lists
        the category's commands when run.

        :param command: the command to expose
        :param jid: the JID the command is exposed on; defaults to the
            gateway's bound JID
        :raises RuntimeError: if an uncategorized command with the same node
            was already registered
        """
        if jid is None:
            jid = self.xmpp.boundjid
        elif not isinstance(jid, JID):
            jid = JID(jid)

        if (category := command.CATEGORY) is None:
            if command.NODE in self._commands:
                raise RuntimeError(
                    f"There is already a command for the node '{command.NODE}'"
                )
            self._commands[command.NODE] = command
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=command.NODE,
                name=strip_leading_emoji_if_needed(command.NAME),
                handler=partial(self.__wrap_initial_handler, command),
            )
            return

        if isinstance(category, str):
            # A bare string serves as both the category's node and its name.
            category = CommandCategory(category, category)
        node = category.node
        if node not in self._categories:
            # First command in this category: expose the category itself.
            self._categories[node] = list[Command]()
            self.xmpp.plugin["xep_0050"].add_command(  # type: ignore[no-untyped-call]
                jid=jid,
                node=node,
                name=strip_leading_emoji_if_needed(category.name),
                handler=partial(self.__handle_category_list, category),
            )
        self._categories[node].append(command)

    def __is_node_visible(self, node: str, ifrom: JID) -> bool:
        """
        Return whether the adhoc entry *node* should be listed for *ifrom*.

        A category node is listed if *ifrom* may run at least one of its
        commands. An uncategorized command node is listed if *ifrom* may run
        it. Nodes that were not registered through this provider are listed
        unfiltered instead of raising ``KeyError``.
        """
        if node in self._categories:
            commands = self._categories[node]
            # An empty category was listed by the previous implementation too.
            return not commands or any(self.__can_run(c, ifrom) for c in commands)
        command = self._commands.get(node)
        if command is None:
            return True
        return self.__can_run(command, ifrom)

    async def get_items(self, jid: JID, node: str, iq: Iq) -> DiscoItems:
        """
        Get items for a disco query.

        :param jid: the JID the disco query is addressed to
        :param node: which command node is requested
        :param iq: the disco query IQ
        :return: only the commands accessible to the sender of *iq*
        """
        all_items = self.xmpp.plugin["xep_0030"].static.get_items(jid, node, None, None)
        log.debug("Static items: %r", all_items)
        if not all_items:
            return DiscoItems()

        ifrom = iq.get_from()

        filtered_items = DiscoItems()
        filtered_items["node"] = self.xmpp.plugin["xep_0050"].stanza.Command.namespace
        for item in all_items:
            if self.__is_node_visible(item["node"], ifrom):
                filtered_items.append(item)

        return filtered_items

273 

274 

def strip_leading_emoji_if_needed(text: str) -> str:
    """
    Return *text* with its leading emoji removed when the
    ``STRIP_LEADING_EMOJI_ADHOC`` config option is truthy, or *text*
    unchanged otherwise.
    """
    return strip_leading_emoji(text) if config.STRIP_LEADING_EMOJI_ADHOC else text

279 

280 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)