Coverage for slidge / command / base.py: 93%

206 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-02-15 09:02 +0000

1from abc import ABC 

2from collections.abc import Awaitable, Callable, Iterable, Sequence 

3from dataclasses import dataclass, field 

4from enum import Enum 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Optional, 

9 TypedDict, 

10 Union, 

11) 

12 

13from slixmpp import JID 

14from slixmpp.exceptions import XMPPError 

15from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

16from slixmpp.plugins.xep_0004 import FormField as SlixFormField 

17from slixmpp.types import JidStr 

18 

19from ..core import config 

20from ..util.types import AnyBaseSession, FieldType 

21 

# Common prefix of the ad-hoc command nodes of slidge's core commands.
NODE_PREFIX = "https://slidge.im/command/core/"

if TYPE_CHECKING:
    # Only needed by annotations; presumably kept out of runtime imports to
    # avoid circular imports with the core package.
    from ..core.gateway import BaseGateway
    from ..core.session import BaseSession
    from .categories import CommandCategory


# Signature of a handler callable: it receives a session and a JID and returns
# a command response, either directly or as an awaitable.
HandlerType = (
    Callable[[AnyBaseSession, JID], "CommandResponseType"]
    | Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]]
)

# Mapping of a form field ``var`` to the value submitted for it.
FormValues = dict[str, str | JID | bool]

# Async callable receiving the values of a submitted form, the session and a JID.
FormHandlerType = Callable[
    [FormValues, AnyBaseSession, JID],
    Awaitable["CommandResponseType"],
]

# Async callable run for a confirmation; the session argument may be None.
ConfirmationHandlerType = Callable[
    [AnyBaseSession | None, JID],
    Awaitable["CommandResponseType"],
]

46 

47 

@dataclass
class TableResult:
    """
    Tabular data returned as the result of a command.

    Rendered by :meth:`get_xml` as a data form of type ``result`` with a
    ``<reported>`` header.
    """

    fields: Sequence["FormField"]
    """
    The 'columns names' of the table.
    """
    items: Sequence[dict[str, str | JID]]
    """
    The rows of the table. Each row is a dict where keys are the fields ``var``
    attribute.
    """
    description: str
    """
    A description of the content of the table.
    """

    # NOTE(review): not read in this class; by its name it flags the JIDs of
    # the table as group chats (MUCs) — confirm with its consumers.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp form representing this table.

        The form type is ``result`` and its title is :attr:`description`.
        Each entry of :attr:`fields` becomes a reported column. Each row of
        :attr:`items` becomes an item whose values are converted with ``str()``.

        :return: some XML
        """
        slix_form = SlixForm()  # type: ignore[no-untyped-call]
        slix_form["type"] = "result"
        slix_form["title"] = self.description
        for column in self.fields:
            slix_form.add_reported(column.var, label=column.label, type=column.type)  # type: ignore[no-untyped-call]
        for row in self.items:
            stringified_row = {key: str(val) for key, val in row.items()}
            slix_form.add_item(stringified_row)  # type: ignore[no-untyped-call]
        return slix_form

84 

85 

@dataclass
class SearchResult(TableResult):
    """
    Results of the search command (search for contacts via Jabber Search).

    Return type of :meth:`BaseSession.search`. It behaves exactly like
    :class:`TableResult`, except that ``description`` has a default value.
    """

    description: str = "Contact search results"

96 

@dataclass
class Confirmation:
    """
    A confirmation 'dialog': a question the user accepts or declines.
    """

    prompt: str
    """
    The text presented to the command triggering user
    """
    handler: ConfirmationHandlerType
    """
    An async function that should return a ResponseType
    """
    success: str | None = None
    """
    Text in case of success, used if handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form asking the user for confirmation.

        The form is titled with :attr:`prompt`. It contains a single boolean
        field named ``confirm``, labelled "Confirm" and pre-set to true.

        :return: some xml
        """
        confirm_field = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        slix_form = SlixForm()  # type: ignore[no-untyped-call]
        slix_form["type"] = "form"
        slix_form["title"] = self.prompt
        slix_form.append(confirm_field.get_xml())
        return slix_form

139 

140 

@dataclass
class Form:
    """
    A form, to request user input.
    """

    title: str
    instructions: str
    fields: Sequence["FormField"]
    handler: FormHandlerType
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Parse and validate a form submission.

        Every field of :attr:`fields` gets an entry, validated with
        :meth:`FormField.validate`. A field absent from the submission is
        validated as ``None``.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each ``field.var`` to its validated value:
            a string, a JID (``jid-single``), a bool (``boolean``), a list of
            strings or JIDs (``list-multi``/``jid-multi``), or None
        :raises XMPPError: if a submitted value is not valid for its field
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp form of type ``form``, with :attr:`title`,
        :attr:`instructions` and one XML field per entry of :attr:`fields`.

        :return: some XML
        """
        slix_form = SlixForm()  # type: ignore[no-untyped-call]
        slix_form["type"] = "form"
        slix_form["title"] = self.title
        slix_form["instructions"] = self.instructions
        for form_field in self.fields:
            slix_form.append(form_field.get_xml())
        return slix_form

184 

185 

class CommandAccess(int, Enum):
    """
    Defines who can access a given Command.

    Checked by :meth:`Command.raise_if_not_authorized`, after the JID has
    passed the gateway's JID validator (or is an admin).
    """

    ADMIN_ONLY = 0  # only JIDs whose bare JID is in config.ADMINS
    USER = 1  # entities that have a session
    USER_LOGGED = 2  # entities whose session is logged in
    USER_NON_LOGGED = 3  # entities whose session exists but is not logged in
    NON_USER = 4  # entities without a session
    ANY = 5  # no additional restriction

197 

198 

class Option(TypedDict):
    """
    One choice offered by a ``FormField`` of type ``list-*``.
    """

    label: str  # passed as the option's label when the field XML is built
    value: str  # the string a submission must match to select this option

206 

207 

# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.

    The same class also describes the fields of command forms (:class:`Form`)
    and of confirmation dialogs (:class:`Confirmation`).
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """
    Choices for ``list-single`` and ``list-multi`` fields. When set on a
    ``jid-multi`` field, submitted JIDs must also be among these values.
    """

    image_url: str | None = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # ``private`` takes precedence over whatever ``type`` was given.
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        # The ``value`` of every declared option. Asking for them on a field
        # without options is a programming error.
        if self.options is None:
            raise RuntimeError(
                f"Field '{self.var}' of type '{self.type}' declares no options"
            )
        return [x["value"] for x in self.options]

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Check that a submitted value is valid for this field, and convert it.

        :param value: The value to test, as parsed from the form submission
        :return: ``None`` for a missing optional value. Otherwise the value
            converted according to the field type: a JID for ``jid-single``,
            a bool for ``boolean``, a list of strings for ``list-multi``, a
            list of JIDs for ``jid-multi``, and the value unchanged for other
            types.
        :raises XMPPError: ``not-acceptable`` if the value is invalid for this
            field, or if a required value is missing
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi"):
            if not value:
                value = []
            if isinstance(value, list):
                return self.__validate_list_multi(value)
            raise XMPPError("not-acceptable", "Multiple values was expected")

        # Type narrowing: lists were rejected or handled above.
        assert isinstance(value, (str, bool, JID)) or value is None

        if value is None:
            if self.required:
                # Fall back to ``var`` so the user never reads "Missing field: 'None'".
                raise XMPPError(
                    "not-acceptable", f"Missing field: '{self.label or self.var}'"
                )
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError as e:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'") from e

        if self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")
        elif self.type == "boolean":
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        # COMPAT: empty values are dropped as a workaround for https://codeberg.org/slidge/slidge/issues/43
        # This should be reverted once the bug is fixed upstream, cf https://soprani.ca/todo/390
        non_empty = [v for v in value if v]

        # list-multi values must be declared options. jid-multi fields usually
        # carry no options (XEP-0004); they are restricted only when options
        # were explicitly given, instead of crashing on the missing list.
        if non_empty and (self.type == "list-multi" or self.options is not None):
            acceptable = set(self.__acceptable_options())  # built once, O(1) lookups
            for v in non_empty:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")

        if self.type == "list-multi":
            return non_empty

        jids: list[JID] = []
        for v in non_empty:
            try:
                jids.append(JID(v))
            except ValueError as e:
                # Same error as jid-single rather than an uncaught ValueError.
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'") from e
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            # NOTE: the media type is always declared as PNG, whatever the URL points to.
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

320 

321 

# Anything a command step may return: a table, a confirmation dialog, a form
# to fill in, a plain text message, or nothing.
CommandResponseType = TableResult | Confirmation | Form | str | None

323 

324 

class Command(ABC):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Defining a subclass is enough to register it in :attr:`subclasses`.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    # Registry filled by __init_subclass__. Unless a subclass shadows it, this
    # single list is shared by the whole class hierarchy.
    subclasses = list[type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Cooperate with other classes in the MRO that define __init_subclass__.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self, session: Optional["BaseSession[Any, Any]"], ifrom: JID, *args: str
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: ``feature-not-implemented`` unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        # The bare JID must pass the gateway's JID validator, unless it is an admin.
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j)) or j in config.ADMINS

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: Optional["BaseSession[Any, Any]"] = None,
    ) -> Optional["BaseSession[Any, Any]"]:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look up the session of ``jid`` and
            ignore the ``session`` argument
        :param session: session of ``jid``, used when ``fetch_session`` is False

        :raises XMPPError: ``bad-request`` if the JID may not use the gateway,
            or if a NON_USER command is used by an entity that has a session.
            ``not-authorized`` if an ADMIN_ONLY command is used by a non-admin.
            ``forbidden`` if the session state does not match :attr:`ACCESS`.
        :return: session of JID if it exists
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        # ANY (or an unset ACCESS) matches no branch: no additional restriction.
        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session

444 

445 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    Only the bare part of the JID is compared with ``config.ADMINS``, so any
    resource of an admin account counts as admin.

    :param jid: the JID to check
    :return: True if the bare JID is listed in ``config.ADMINS``
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS