Coverage for slidge / command / base.py: 94%

204 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-03-13 22:59 +0000

1from abc import ABC 

2from collections.abc import Awaitable, Callable, Iterable, Sequence 

3from dataclasses import dataclass, field 

4from enum import Enum 

5from typing import ( 

6 TYPE_CHECKING, 

7 Any, 

8 Generic, 

9 TypedDict, 

10 Union, 

11) 

12 

13from slixmpp import JID 

14from slixmpp.exceptions import XMPPError 

15from slixmpp.plugins.xep_0004 import Form as SlixForm # type: ignore[attr-defined] 

16from slixmpp.plugins.xep_0004 import FormField as SlixFormField 

17from slixmpp.types import JidStr 

18 

19from ..core import config 

20from ..util.types import AnyBaseSession, FieldType, SessionType 

21 

# Common URI prefix for command nodes (presumably prepended to the NODE of
# slidge's built-in commands -- confirm at the call sites).
NODE_PREFIX = "https://slidge.im/command/core/"

23 

24if TYPE_CHECKING: 

25 from ..core.gateway import BaseGateway 

26 from .categories import CommandCategory 

27 

28 

# A handler receives a session and a JID, and gives back a command response,
# either directly (plain callable) or through an awaitable (async callable).
_SyncHandler = Callable[[AnyBaseSession, JID], "CommandResponseType"]
_AsyncHandler = Callable[[AnyBaseSession, JID], Awaitable["CommandResponseType"]]
HandlerType = _SyncHandler | _AsyncHandler

# Mapping of string keys to the values of a form, as passed to ``Form.handler``.
FormValues = dict[str, str | JID | bool]

35 

36 

@dataclass
class TableResult:
    """
    Tabular data handed back by a command.

    :meth:`get_xml` renders it as a slixmpp data form of type ``result``.
    """

    fields: Sequence["FormField"]
    """
    The columns of the table; each one becomes a <reported> entry.
    """
    items: Sequence[dict[str, str | JID]]
    """
    The rows of the table. Every row is a dict keyed by the ``var`` attribute
    of the fields.
    """
    description: str
    """
    What the table contains; used as the title of the form.
    """

    # Not read by this class; presumably a hint for consumers that the JIDs
    # in the rows are rooms -- confirm against callers.
    jids_are_mucs: bool = False

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" (with its <reported> header) holding the data

        :return: some XML
        """
        table = SlixForm()  # type: ignore[no-untyped-call]
        table["type"] = "result"
        table["title"] = self.description
        for column in self.fields:
            table.add_reported(column.var, label=column.label, type=column.type)  # type: ignore[no-untyped-call]
        for row in self.items:
            # every cell is sent as text, JIDs included
            stringified = {key: str(cell) for key, cell in row.items()}
            table.add_item(stringified)  # type: ignore[no-untyped-call]
        return table

73 

74 

@dataclass
class SearchResult(TableResult):
    """
    Table of contacts found by the search command (Jabber Search).

    This is what :meth:`BaseSession.search` returns.
    """

    # Same field as in TableResult, but with a ready-made default.
    description: str = "Contact search results"

84 

85 

@dataclass
class Confirmation(Generic[SessionType]):
    """
    A yes/no 'dialog' shown before going further with a command
    """

    prompt: str
    """
    The question shown to the user who triggered the command
    """
    handler: Callable[
        [SessionType | None, JID],
        Awaitable["CommandResponseType"],
    ]
    """
    An async function that should return a ResponseType
    """
    success: str | None = None
    """
    Text used in case of success when the handler does not return anything
    """
    handler_args: Iterable[Any] = field(default_factory=list)
    """
    positional arguments passed to the handler
    """
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """
    keyword arguments passed to the handler
    """

    def get_form(self) -> SlixForm:
        """
        Build the slixmpp form: the prompt as title and a single pre-checked
        boolean "confirm" field.

        :return: some xml
        """
        dialog = SlixForm()  # type: ignore[no-untyped-call]
        dialog["type"] = "form"
        dialog["title"] = self.prompt
        confirm_field = FormField(
            "confirm", type="boolean", value="true", label="Confirm"
        )
        dialog.append(confirm_field.get_xml())
        return dialog

131 

132 

@dataclass
class Form(Generic[SessionType]):
    """
    A form presented to the user to collect some input
    """

    title: str
    """Title of the form"""
    instructions: str
    """Instructions of the form"""
    fields: Sequence["FormField"]
    """The fields of the form, in the order they are added to the XML"""
    handler: Callable[
        [FormValues, SessionType | None, JID],
        Awaitable["CommandResponseType"],
    ]
    """An async function taking the form values, the session and a JID"""
    handler_args: Iterable[Any] = field(default_factory=list)
    """positional arguments passed to the handler"""
    handler_kwargs: dict[str, Any] = field(default_factory=dict)
    """keyword arguments passed to the handler"""
    timeout_handler: Callable[[], None] | None = None
    """Optional callback without arguments (presumably run on timeout -- confirm)"""

    def get_values(
        self, slix_form: SlixForm
    ) -> dict[str, list[str] | list[JID] | str | JID | bool | None]:
        """
        Parse a form submission, validating each declared field

        :param slix_form: the xml received as the submission of a form
        :return: A dict where keys=field.var and values are whatever
            ``FormField.validate`` returns for the submitted value
        """
        submitted: dict[str, str] = slix_form.get_values()  # type: ignore[no-untyped-call]
        # only declared fields are kept; absent ones are validated as None
        return {
            form_field.var: form_field.validate(submitted.get(form_field.var))
            for form_field in self.fields
        }

    def get_xml(self) -> SlixForm:
        """
        Build the slixmpp "form" to send to the user

        :return: some XML
        """
        xml = SlixForm()  # type: ignore[no-untyped-call]
        xml["type"] = "form"
        xml["title"] = self.title
        xml["instructions"] = self.instructions
        for form_field in self.fields:
            xml.append(form_field.get_xml())
        return xml

179 

180 

class CommandAccess(int, Enum):
    """
    Audience allowed to run a given Command

    The comments below describe what ``Command.raise_if_not_authorized``
    enforces for each level.
    """

    # caller must be an admin
    ADMIN_ONLY = 0
    # caller must have a session
    USER = 1
    # caller must have a session that is logged
    USER_LOGGED = 2
    # caller must have a session that is not logged
    USER_NON_LOGGED = 3
    # caller must not have a session
    NON_USER = 4
    # no check specific to the access level
    ANY = 5

192 

193 

class Option(TypedDict):
    """
    One choice offered by a ``FormField`` of type ``list-*``
    """

    # text associated with the choice
    label: str
    # the value that is checked against when a submission is validated
    value: str

201 

202 

203# TODO: support forms validation XEP-0122 

@dataclass
class FormField:
    """
    Represents a field of the form that a user will see when registering to the gateway
    via their XMPP client.
    """

    var: str = ""
    """
    Internal name of the field, will be used to retrieve via :py:attr:`slidge.GatewayUser.registration_form`
    """
    label: str | None = None
    """Description of the field that the user will see"""
    required: bool = False
    """Whether this field is mandatory or not"""
    private: bool = False
    """
    For sensitive info that should not be displayed on screen while the user types.
    Forces field_type to "text-private"
    """
    type: FieldType = "text-single"
    """Type of the field, see `XEP-0004 <https://xmpp.org/extensions/xep-0004.html#protocol-fieldtypes>`_"""
    value: str = ""
    """Pre-filled value. Will be automatically pre-filled if a registered user modifies their subscription"""
    options: list[Option] | None = None
    """
    The acceptable choices. Submitted values of ``list-single``, ``list-multi``
    and ``jid-multi`` fields are checked against the ``value`` of these options.
    """

    image_url: str | None = None
    """An image associated to this field, eg, a QR code"""

    def __post_init__(self) -> None:
        # a private field is always rendered as "text-private", whatever
        # type was passed to the constructor
        if self.private:
            self.type = "text-private"

    def __acceptable_options(self) -> list[str]:
        """
        List the ``value`` of every option of this field.

        :raise RuntimeError: if the field has no options at all
        """
        if self.options is None:
            raise RuntimeError
        return [x["value"] for x in self.options]

    def validate(
        self, value: str | list[str] | None
    ) -> list[str] | list[JID] | str | JID | bool | None:
        """
        Raise appropriate XMPPError if a given value is not valid for this field

        :param value: The value to test
        :return: The same value, OR a JID if ``self.type=jid-single``, a list
            of JIDs if ``self.type=jid-multi``, a bool if
            ``self.type=boolean``. A missing value gives an empty list for
            ``*-multi`` fields and ``None`` for the other ones.
        :raise XMPPError: "not-acceptable" if the value does not fit the field
        :raise RuntimeError: if the value has to be checked against the
            options but the field has none
        """
        if isinstance(value, list) and not self.type.endswith("multi"):
            raise XMPPError("not-acceptable", "A single value was expected")

        if self.type in ("list-multi", "jid-multi", "text-multi"):
            # a missing (or empty) submission is treated as "no value selected"
            if not value:
                value = []
            if isinstance(value, list):
                if self.type == "text-multi":
                    return value
                return self.__validate_list_multi(value)
            else:
                raise XMPPError("not-acceptable", "Multiple values was expected")

        # lists have all been handled above, this narrows the type
        assert isinstance(value, (str, bool, JID)) or value is None

        if self.required and value is None:
            raise XMPPError("not-acceptable", f"Missing field: '{self.label}'")

        if value is None:
            return None

        if self.type == "jid-single":
            try:
                return JID(value)
            except ValueError as exc:
                # keep the parsing error as the cause, for debugging
                raise XMPPError(
                    "not-acceptable", f"Not a valid JID: '{value}'"
                ) from exc

        elif self.type == "list-single":
            if value not in self.__acceptable_options():
                raise XMPPError("not-acceptable", f"Not a valid option: '{value}'")

        elif self.type == "boolean":
            # "1" and "true" (any case) are true, any other string is false
            return value.lower() in ("1", "true") if isinstance(value, str) else value

        return value

    def __validate_list_multi(self, value: list[str]) -> list[str] | list[JID]:
        """
        Check every submitted value against the options of this field.

        :param value: the submitted values
        :return: the same list, or a list of JIDs unless ``self.type=list-multi``
        :raise XMPPError: "not-acceptable" on the first value that is not
            among the options
        """
        # The options are only looked up when there is something to check, so
        # an empty submission never fails on a field without options.
        if value:
            acceptable = self.__acceptable_options()
            for v in value:
                if v not in acceptable:
                    raise XMPPError("not-acceptable", f"Not a valid option: '{v}'")
        if self.type == "list-multi":
            return value
        return [JID(v) for v in value]

    def get_xml(self) -> SlixFormField:
        """
        Get the field in slixmpp format

        :return: some XML
        """
        f = SlixFormField()
        f["var"] = self.var
        f["label"] = self.label
        f["required"] = self.required
        f["type"] = self.type
        if self.options:
            for o in self.options:
                f.add_option(**o)  # type: ignore[no-untyped-call]
        f["value"] = self.value
        if self.image_url:
            f["media"].add_uri(self.image_url, itype="image/png")
        return f

313 

314 

# Everything a command (or a form/confirmation handler) may give back:
# a table, a confirmation dialog, a form, a plain text, or nothing at all.
CommandResponseType = (
    TableResult | Confirmation | Form | str | None
)

316 

317 

class Command(ABC, Generic[SessionType]):
    """
    Abstract base class to implement gateway commands (chatbot and ad-hoc)

    Every subclass is recorded in :attr:`subclasses` when it is defined.
    """

    NAME: str = NotImplemented
    """
    Friendly name of the command, eg: "do something with stuff"
    """
    HELP: str = NotImplemented
    """
    Long description of what the command does
    """
    NODE: str = NotImplemented
    """
    Name of the node used for ad-hoc commands
    """
    CHAT_COMMAND: str = NotImplemented
    """
    Text to send to the gateway to trigger the command via a message
    """

    ACCESS: "CommandAccess" = NotImplemented
    """
    Who can use this command
    """

    CATEGORY: Union[str, "CommandCategory"] | None = None
    """
    If used, the command will be under this top-level category.
    Use the same string for several commands to group them.
    This hierarchy only used for the adhoc interface, not the chat command
    interface.
    """

    subclasses = list[type["Command"]]()

    def __init__(self, xmpp: "BaseGateway") -> None:
        self.xmpp = xmpp

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Chain up first: Generic has its own subclass hook, and any other
        # cooperating base or mixin must see the keyword arguments too.
        super().__init_subclass__(**kwargs)
        # store subclasses so subclassing is enough for the command to be
        # picked up by slidge
        cls.subclasses.append(cls)

    async def run(
        self,
        session: SessionType | None,
        ifrom: JID,
        *args: str,
    ) -> CommandResponseType:
        """
        Entry point of the command

        :param session: If triggered by a registered user, its slidge Session
        :param ifrom: JID of the command-triggering entity
        :param args: When triggered via chatbot type message, additional words
            after the CHAT_COMMAND string was passed

        :return: Either a TableResult, a Form, a Confirmation, a text, or None
        :raise XMPPError: "feature-not-implemented" unless overridden
        """
        raise XMPPError("feature-not-implemented")

    def _get_session(self, jid: JID) -> SessionType | None:
        return self.xmpp.get_session_from_jid(jid)

    def __can_use_command(self, jid: JID) -> bool:
        # allowed when the bare JID matches the gateway's JID validator, or
        # when it is listed among the admins
        j = jid.bare
        return bool(self.xmpp.jid_validator.match(j) or j in config.ADMINS)

    def raise_if_not_authorized(
        self,
        jid: JID,
        fetch_session: bool = True,
        session: SessionType | None = None,
    ) -> SessionType | None:
        """
        Raise an appropriate error if jid is not authorized to use the command

        :param jid: jid of the entity trying to access the command
        :param fetch_session: if true (the default), the session is looked up
            from ``jid`` and the ``session`` argument is ignored
        :param session: session to check against ``ACCESS``, only used when
            ``fetch_session`` is false

        :return: session of JID if it exists
        :raise XMPPError: "bad-request", "not-authorized" or "forbidden",
            depending on which requirement is not met
        """
        if not self.__can_use_command(jid):
            raise XMPPError(
                "bad-request", "Your JID is not allowed to use this gateway."
            )
        if fetch_session:
            session = self._get_session(jid)

        if self.ACCESS == CommandAccess.ADMIN_ONLY and not is_admin(jid):
            raise XMPPError("not-authorized")
        elif self.ACCESS == CommandAccess.NON_USER and session is not None:
            raise XMPPError(
                "bad-request", "This is only available for non-users. Unregister first."
            )
        elif self.ACCESS == CommandAccess.USER and session is None:
            raise XMPPError(
                "forbidden",
                "This is only available for users that are registered to this gateway",
            )
        elif self.ACCESS == CommandAccess.USER_NON_LOGGED:
            if session is None or session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available for users that are not logged to the"
                        " legacy service"
                    ),
                )
        elif self.ACCESS == CommandAccess.USER_LOGGED:
            if session is None or not session.logged:
                raise XMPPError(
                    "forbidden",
                    (
                        "This is only available when you are logged in to the legacy"
                        " service"
                    ),
                )
        return session

440 

441 

def is_admin(jid: JidStr) -> bool:
    """
    Tell whether a JID belongs to an admin

    :param jid: a JID, or its string form; only its bare part is considered
    :return: True if the bare JID is listed in ``config.ADMINS``
    """
    bare_jid = JID(jid).bare
    return bare_jid in config.ADMINS