Coverage for slidge/core/dispatcher/disco.py: 94%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import logging 

2from typing import TYPE_CHECKING, Any, Optional 

3 

4from slixmpp.exceptions import XMPPError 

5from slixmpp.plugins.xep_0030.stanza.items import DiscoItems 

6from slixmpp.types import OptJid 

7 

8from .util import DispatcherMixin 

9 

10if TYPE_CHECKING: 

11 from slidge.core.gateway import BaseGateway 

12 

13 

class DiscoMixin(DispatcherMixin):
    """
    Dispatcher mixin that answers XEP-0030 service discovery queries.

    On construction, :meth:`get_info` and :meth:`get_items` are registered on
    the ``xep_0030`` plugin as its ``get_info`` / ``get_items`` node handlers,
    with ``jid=None`` and ``node=None`` (the global level of slixmpp's
    node-handler hierarchy).
    """

    def __init__(self, xmpp: "BaseGateway"):
        super().__init__(xmpp)

        disco_plugin = xmpp.plugin["xep_0030"]
        # Both query types are registered identically; only the handler
        # type and the bound callback differ.
        for handler_type, callback in (
            ("get_info", self.get_info),
            ("get_items", self.get_items),
        ):
            disco_plugin.set_node_handler(
                handler_type,
                jid=None,
                node=None,
                handler=callback,
            )

    async def get_info(
        self, jid: OptJid, node: Optional[str], ifrom: OptJid, data: Any
    ):
        """
        Answer a disco#info query.

        Static handling by the ``xep_0030`` plugin applies when:

        - the query comes from the gateway's own bare JID, or
        - the query targets the gateway's bare JID, or
        - the query has no target JID.

        In every other case, the requester's session is looked up. The target
        JID is resolved to a contact, group or participant, and that entity's
        own disco info is returned.

        :param jid: JID the query is addressed to.
        :param node: Disco node requested, if any.
        :param ifrom: JID of the requesting entity.
        :param data: Only forwarded to the static handler.
        :raises XMPPError: ``subscription-required`` when there is no
            requester, ``item-not-found`` when no entity matches *jid*.
        """
        own_jid = self.xmpp.boundjid.bare
        if ifrom == own_jid or jid in (own_jid, None):
            static_disco = self.xmpp.plugin["xep_0030"].static
            return static_disco.get_info(jid, node, ifrom, data)

        if ifrom is None:
            raise XMPPError("subscription-required")

        # Guaranteed by the first branch (jid None is handled statically);
        # kept for type narrowing.
        assert jid is not None
        user_session = await self._get_session_from_jid(jid=ifrom)

        log.debug("Looking for entity: %s", jid)
        target = await user_session.get_contact_or_group_or_participant(jid)
        if target is None:
            raise XMPPError("item-not-found")

        return await target.get_disco_info(jid, node)

    async def get_items(
        self, jid: OptJid, node: Optional[str], ifrom: OptJid, data: Any
    ):
        """
        Answer a disco#items query.

        Only a query addressed to the gateway's bare JID lists anything. Each
        room that the store returns for the requester's session is advertised
        with its JID and name. Every other target JID gets an empty item list.

        :param jid: JID the query is addressed to.
        :param node: Disco node requested; not used here.
        :param ifrom: JID of the requesting entity.
        :param data: Not used here.
        :raises XMPPError: ``bad-request`` when there is no requester.
        """
        if ifrom is None:
            raise XMPPError("bad-request")

        if jid != self.xmpp.boundjid.bare:
            return DiscoItems()

        user_session = await self._get_session_from_jid(ifrom)

        items = DiscoItems()
        room_store = self.xmpp.store.rooms
        for room in room_store.get_all_jid_and_names(user_session.user_pk):
            items.add_item(room.jid, name=room.name)
        return items

70 

71 

# Module-level logger named after this module. The methods above reference it
# as a global, which is resolved at call time, so defining it last is fine.
log = logging.getLogger(name=__name__)