Coverage for slidge / core / dispatcher / disco.py: 91%

47 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 05:07 +0000

1import logging 

2from typing import TYPE_CHECKING, Any 

3 

4import sqlalchemy as sa 

5from slixmpp.exceptions import XMPPError 

6from slixmpp.plugins.xep_0030.stanza.info import DiscoInfo 

7from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

8from slixmpp.types import OptJid 

9 

10from ...db.models import Room 

11from .util import DispatcherMixin 

12 

13if TYPE_CHECKING: 

14 from slidge.core.gateway import BaseGateway 

15 from slidge.util.types import AnySession 

16 

17 

class DiscoMixin(DispatcherMixin):
    """
    Answer XEP-0030 (Service Discovery) info and items queries for the
    gateway itself and for the entities (contacts, groups, participants)
    reachable through a user's session.
    """

    __slots__: list[str] = []

    def __init__(self, xmpp: "BaseGateway") -> None:
        """
        Register :meth:`get_info` and :meth:`get_items` with the ``xep_0030``
        plugin.

        Both handlers are registered with ``jid=None`` and ``node=None``,
        presumably making them the fallback handlers for every JID/node
        (slixmpp semantics, not visible here).

        :param xmpp: the gateway component; also forwarded to the parent mixin.
        """
        super().__init__(xmpp)

        xmpp.plugin["xep_0030"].set_node_handler(
            "get_info",
            jid=None,
            node=None,
            handler=self.get_info,
        )

        xmpp.plugin["xep_0030"].set_node_handler(
            "get_items",
            jid=None,
            node=None,
            handler=self.get_items,
        )

    async def get_info(
        self,
        jid: OptJid,
        node: str | None,
        ifrom: OptJid,
        data: Any,  # noqa:ANN401
    ) -> DiscoInfo | None:
        """
        Handle a disco#info query.

        The ``xep_0030`` plugin's static handler answers these queries:

        - queries sent from the gateway's own bare JID;
        - queries addressed to the gateway's bare JID;
        - queries with no target JID.

        Any other query is resolved through the requester's session and
        delegated to the matching contact, group or participant.

        :param jid: the JID being queried.
        :param node: the disco node being queried, if any.
        :param ifrom: the JID of the requester.
        :param data: extra handler data, only passed through to the static
            handler.
        :return: the disco#info payload for the target.
        :raises XMPPError: ``subscription-required`` when ``ifrom`` is
            ``None``; ``item-not-found`` when the session has no contact,
            group or participant for ``jid``. Session lookup may raise as
            well (defined in ``DispatcherMixin``; not visible here).
        """
        gateway_jid = self.xmpp.boundjid.bare
        if ifrom == gateway_jid or jid in (gateway_jid, None):
            return self.xmpp.plugin["xep_0030"].static.get_info(jid, node, ifrom, data)

        if ifrom is None:
            raise XMPPError("subscription-required")

        # Always true here: jid=None already returned above. The assert only
        # narrows the type for static checkers.
        assert jid is not None
        session = await self._get_session_from_jid(jid=ifrom)

        log.debug("Looking for entity: %s", jid)
        entity = await session.get_contact_or_group_or_participant(jid)
        if entity is None:
            raise XMPPError("item-not-found")

        return await entity.get_disco_info(jid, node)

    async def get_items(
        self,
        jid: OptJid,
        node: str | None,
        ifrom: OptJid,
        data: Any,  # noqa:ANN401
    ) -> DiscoItems:
        """
        Handle a disco#items query.

        For the gateway itself (its bare JID, or a falsy target JID):

        - the root node (no ``node``) lists the requesting user's rooms;
        - any other node returns an empty result.

        Any other target is resolved through the requester's session and
        delegated to the matching contact, group or participant.

        :param jid: the JID being queried.
        :param node: the disco node being queried, if any.
        :param ifrom: the JID of the requester.
        :param data: extra handler data; not used by this handler.
        :return: the disco#items payload for the target.
        :raises XMPPError: ``bad-request`` when ``ifrom`` is ``None``;
            ``item-not-found`` when the session has no contact, group or
            participant for ``jid``.
        """
        if ifrom is None:
            raise XMPPError("bad-request")

        session = await self._get_session_from_jid(jid=ifrom)

        if jid == self.xmpp.boundjid.bare or not jid:
            # Only the gateway's root node has items: the user's rooms.
            if node:
                return DiscoItems()
            return self.__list_rooms(session)

        entity = await session.get_contact_or_group_or_participant(jid)
        if entity is None:
            raise XMPPError("item-not-found")

        return await entity.get_disco_items(node)

    def __list_rooms(self, session: "AnySession") -> DiscoItems:
        """
        Build a disco#items result with one item per room stored for the
        session's user, ordered by room name.

        Only the ``jid`` and ``name`` columns are loaded from the database.
        """
        items = DiscoItems()
        with self.xmpp.store.session() as orm:
            rooms = orm.execute(
                sa.select(Room)
                .options(sa.orm.load_only(Room.jid, Room.name))
                .filter_by(user=session.user)
                .order_by(Room.name)
            ).scalars()
            for room in rooms:
                items.add_item(room.jid, name=room.name)

        return items

101 

102 

# Module-level logger named after this module; the DiscoMixin methods look it
# up as a global at call time, which is why it can be defined after the class.
log = logging.getLogger(name=__name__)