Coverage for slidge/command/base.py: 93%

206 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

1from abc import ABC 

2from dataclasses import dataclass, field 

3from enum import Enum 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Awaitable, 

8 Callable, 

9 Iterable, 

10 Optional, 

11 Sequence, 

12 Type, 

13 TypedDict, 

14 Union, 

15) 

16 

17from slixmpp import JID 

18from slixmpp.exceptions import XMPPError 

19from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

20from slixmpp.plugins.xep_0004 import FormField as SlixFormField 

21from slixmpp.types import JidStr 

22 

23from ..core import config 

24from ..db.models import GatewayUser 

25from ..util.types import AnyBaseSession, FieldType 

26 

# Common prefix for core command node identifiers; usage is outside this
# chunk — presumably prepended to ad-hoc command nodes (TODO confirm).
NODE_PREFIX = "https://slidge.im/command/core/"

if TYPE_CHECKING:
    from ..core.gateway import BaseGateway
    from ..core.session import BaseSession
    from .categories import CommandCategory


# Signature of a command handler taking a session and a JID, returning a
# command response either directly or as an awaitable.
# "CommandResponseType" is a string forward reference: the alias is defined
# further down in this module.
HandlerType = Union[
    Callable[[AnyBaseSession, JID], "CommandResponseType"],
    Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]],
]

# Form values keyed by field ``var``.
FormValues = dict[str, Union[str, JID, bool]]


# Signature of ``Form.handler``: an async callable taking the form values,
# the session and a JID, and returning a command response.
FormHandlerType = Callable[
    [FormValues, AnyBaseSession, JID],
    Awaitable["CommandResponseType"],
]

# Signature of ``Confirmation.handler``: like FormHandlerType but without form
# values, and the session may be None.
ConfirmationHandlerType = Callable[
    [Optional[AnyBaseSession], JID], Awaitable["CommandResponseType"]
]

51 

52 

@dataclass
class TableResult:
    """
    Tabular, structured data returned as the result of a command.
    """

    fields: Sequence["FormField"]
    """
    The column definitions ('column names') of the table.
    """
    items: Sequence[dict[str, Union[str, JID]]]
    """
    The rows of the table: one dict per row, keyed by the ``var`` attribute of
    the corresponding field.
    """
    description: str
    """
    A human-readable description of what the table contains.
    """

    # Not read by this class; presumably flags that the JIDs in ``items`` are
    # group chats (TODO confirm against callers).
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Render the table as a slixmpp "result" form with a <reported> header.

        Every cell is converted with ``str()``.

        :return: some XML
        """
        result = SlixForm()  # type: ignore[no-untyped-call]
        result["type"] = "result"
        result["title"] = self.description
        for column in self.fields:
            result.add_reported(  # type: ignore[no-untyped-call]
                column.var, label=column.label, type=column.type
            )
        for row in self.items:
            cells = {key: str(cell) for key, cell in row.items()}
            result.add_item(cells)  # type: ignore[no-untyped-call]
        return result

89 

90 

@dataclass
class SearchResult(TableResult):
    """
    Results of the search command (search for contacts via Jabber Search)

    Return type of :meth:`BaseSession.search`.

    Identical to :class:`TableResult`, except that ``description`` has a
    default value.
    """

    # default table title for search results
    description: str = "Contact search results"

98 

99 description: str = "Contact search results" 

100 

101 

@dataclass
class Confirmation:
    """
    A yes/no confirmation 'dialog' presented to the user.
    """

    prompt: str
    """
    Text shown to the user who triggered the command
    """
    handler: ConfirmationHandlerType
    """
    Async function that should return a ResponseType
    """
    success: Optional[str] = None
    """
    Text to use on success, when the handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    Positional arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    Keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form for this confirmation.

        The form carries a single boolean ``confirm`` field, pre-set to true.

        :return: some xml
        """
        confirm_field = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        form = SlixForm()  # type: ignore[no-untyped-call]
        form["type"] = "form"
        form["title"] = self.prompt
        form.append(confirm_field.get_xml())
        return form

144 

145 

@dataclass
class Form:
    """
    A form requesting input from the user.
    """

    title: str
    instructions: str
    fields: Sequence["FormField"]
    handler: FormHandlerType
    handler_args: Iterable[Any] = field(default_factory=list)
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    timeout_handler: Callable[[], None] | None = None

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, Union[list[str], list[JID], str, JID, bool, None]]:
        """
        Parse and validate a form submission.

        Every field declared in ``self.fields`` gets an entry. A field missing
        from the submission is validated as ``None``.

        :param slix_form: the xml received as the submission of a form
        :return: A dict mapping each ``field.var`` to the value returned by
            :meth:`FormField.validate` (e.g. a JID for jid-single fields)
        :raises XMPPError: propagated from :meth:`FormField.validate`
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        return {fi.var: fi.validate(submitted.get(fi.var)) for fi in self.fields}

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" presenting this form's fields.

        :return: some XML
        """
        form = SlixForm()  # type: ignore[no-untyped-call]
        form["type"] = "form"
        form["title"] = self.title
        form["instructions"] = self.instructions
        for field_xml in (fi.get_xml() for fi in self.fields):
            form.append(field_xml)
        return form

189 

190 

class CommandAccess(int, Enum):
    """
    Defines who can access a given Command

    Checked in :meth:`Command.raise_if_not_authorized`, in addition to the
    gateway-wide JID check applied to every command.
    """

    # only JIDs whose bare JID is in config.ADMINS
    ADMIN_ONLY = 0
    # only JIDs that have a session (users registered to this gateway)
    USER = 1
    # registered users whose session is logged in to the legacy service
    USER_LOGGED = 2
    # registered users whose session is NOT logged in to the legacy service
    USER_NON_LOGGED = 3
    # only JIDs without a session (not registered)
    NON_USER = 4
    # no restriction beyond the gateway-wide JID check
    ANY = 5

202 

203 

class Option(TypedDict):
    """
    Options to be used for ``FormField``s of type ``list-*``

    Each option is passed as keyword arguments to slixmpp's ``add_option``.
    """

    # text displayed for this choice
    label: str
    # value compared against submitted values during FormField validation
    value: str

211 

212 

# TODO: support forms validation XEP-0122
@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: Optional[str] = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: Optional[list[Option]] = None
    """
    Acceptable choices. Mandatory for ``list-single`` and ``list-multi`` fields;
    when set on a ``jid-multi`` field, submitted JIDs must be among them.
    """

    image_url: Optional[str] = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # private fields are always rendered as masked text input
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        """
        Return the ``value`` of every option of this field.

        :raises RuntimeError: if this field has no options (a programming error
            for field types that require them)
        """
        if self.options is None:
            raise RuntimeError(
                f"Field '{self.var}' of type '{self.type}' has no options"
            )
        return [x["value"] for x in self.options]

    def validate(
        self, value: Optional[Union[str, list[str]]]
    ) -> Union[list[str], list[JID], str, JID, bool, None]:
        """
        Validate (and convert) a submitted value for this field.

        Raise an appropriate XMPPError if a given value is NOT valid for this
        field.

        :param value: The value to test
        :return: The same value, except: a JID if ``self.type=jid-single``,
            a list of JIDs for ``jid-multi``, the list of non-empty values for
            ``list-multi``, a bool for ``boolean``, and a newline-joined string
            for a ``text-multi`` value submitted as a list of lines.
        :raises XMPPError: ``not-acceptable`` if the value is missing (for a
            required field), has the wrong cardinality, is not a valid JID, or
            is not among the acceptable options.
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi"):
            if not value:
                value = []
            if isinstance(value, list):
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        if self.type == "text-multi" and isinstance(value, list):
            # XEP-0004 sends each line of a text-multi field as its own <value/>,
            # so the submission may arrive as a list of lines: rebuild the text.
            # Without this, the assertion below would fail.
            value = "\n".join(value)

        # internal invariant: only scalar values can remain at this point
        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            # fall back to the var so the message never reads "'None'"
            raise XMPPError(
                "not-acceptable", f"Missing field: '{self.label or self.var}'"
            )

        if value is None:
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError:
                raise XMPPError("not-acceptable", f"Not a valid JID: '{value}'")

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            # the value may already be a bool, in which case it is kept as-is
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> Union[list[str], list[JID]]:
        """
        Validate the values of a ``list-multi`` or ``jid-multi`` field.

        :param value: the submitted values
        :return: the non-empty values, converted to JIDs for ``jid-multi``
        :raises XMPPError: if a value is not an acceptable option or, for
            ``jid-multi``, not a valid JID
        :raises RuntimeError: for a ``list-multi`` field without options
        """
        # COMPAT: skipping empty values is a workaround for https://codeberg.org/slidge/slidge/issues/43
        # It should be reverted once the bug is fixed upstream, cf https://soprani.ca/todo/390
        values = [v for v in value if v]

        # list-multi always needs options; jid-multi is free-form unless options
        # were explicitly provided (previously this raised RuntimeError).
        if values and (self.type == "list-multi" or self.options is not None):
            # built once instead of once per submitted value
            acceptable = set(self.__acceptable_options())
            for v in values:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")

        if self.type == "list-multi":
            return values

        jids: list[JID] = []
        for v in values:
            try:
                jids.append(JID(v))
            except ValueError:
                # same error as for jid-single instead of a raw ValueError
                raise XMPPError("not-acceptable", f"Not a valid JID: '{v}'")
        return jids

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            # NOTE(review): the media type is always declared as PNG,
            # whatever the actual image format.
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

325 

326 

# Anything a command may return (see Command.run): a table, a confirmation
# dialog, a form, plain text, or nothing.
CommandResponseType = Union[TableResult, Confirmation, Form, str, None]

328 

329 

class Command(ABC):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Subclassing is enough to register a command: every subclass is appended
    to :attr:`Command.subclasses`.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Optional[Union[str, "CommandCategory"]] = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy is only used for the adhoc interface, not the chat command
    interface.
    """

    # every (direct or indirect) subclass, in definition order;
    # filled by __init_subclass__
    subclasses = list[Type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Let other __init_subclass__ implementations in the MRO (mixins,
        # object) run, instead of silently swallowing their kwargs.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self, session: Optional["BaseSession[Any, Any]"], ifrom: JID, *args: str
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raises XMPPError: ``feature-not-implemented`` unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> Optional["BaseSession[Any, Any]"]:
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        """
        Whether the bare JID passes the gateway's JID validator or is an admin.
        """
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j)) or j in config.ADMINS

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: Optional["BaseSession[Any, Any]"] = None,
    ) -> Optional["BaseSession[Any, Any]"]:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if True, look up the session of ``jid`` and use
            it instead of the ``session`` argument
        :param session: session used for the access checks when
            ``fetch_session`` is False

        :return: session of JID if it exists
        :raises XMPPError: ``bad-request`` if the JID may not use this gateway
            or (for NON_USER commands) already has a session,
            ``not-authorized`` if an ADMIN_ONLY command is used by a non-admin,
            ``forbidden`` if the session state does not match USER,
            USER_LOGGED or USER_NON_LOGGED
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session

449 

450 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether ``jid`` belongs to a gateway administrator.

    Only the bare part of the JID is compared against ``config.ADMINS``, so
    the resource is ignored.

    :param jid: anything ``JID()`` accepts (a JID or its string form)
    :return: True if the bare JID is listed in ``config.ADMINS``
    """
    bare = JID(jid).bare
    return bare in config.ADMINS