Coverage for slidge/core/dispatcher/caps.py: 90%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import logging 

2from typing import TYPE_CHECKING 

3 

4from slixmpp import Presence 

5from slixmpp.exceptions import XMPPError 

6from slixmpp.xmlstream import StanzaBase 

7 

8from .util import DispatcherMixin 

9 

10if TYPE_CHECKING: 

11 from slidge.core.gateway import BaseGateway 

12 

13 

class CapsMixin(DispatcherMixin):
    """Dispatcher mixin that stamps entity capabilities on outgoing presences.

    On construction it removes the ``xep_0115`` plugin's own
    ``_filter_add_caps`` from the gateway's "out" filters and registers
    this class's ``_filter_add_caps`` in its place.
    """

    def __init__(self, xmpp: "BaseGateway"):
        super().__init__(xmpp)
        # Swap the plugin-provided outgoing filter for ours.
        stock_filter = xmpp.plugin["xep_0115"]._filter_add_caps
        xmpp.del_filter("out", stock_filter)
        xmpp.add_filter("out", self._filter_add_caps)  # type:ignore

    async def _filter_add_caps(self, stanza: StanzaBase) -> StanzaBase:
        """Outgoing-stanza filter that fills in the ``caps`` element.

        The stanza is returned untouched when it is not a presence, when
        it already carries a ``caps`` plugin, when its type is not one of
        the availability types listed below, or when no session/contact
        can be resolved for it. Otherwise, if a non-empty ver string is
        obtained, ``node``, ``hash`` and ``ver`` are set on ``caps``.

        :param stanza: The outgoing stanza handed to the filter.
        :return: The same stanza object, possibly with caps added.
        """
        # Rationale for a home-grown filter: the stock slixmpp one does too
        # much implicitly, and a dedicated "dynamic disco"/caps module is
        # the likely long-term goal anyway -- this is a first step.
        if not isinstance(stanza, Presence):
            return stanza

        presence_types = ("available", "chat", "away", "dnd", "xa")
        if (
            stanza.get_plugin("caps", check=True)
            or stanza["type"] not in presence_types
        ):
            return stanza

        sender = stanza.get_from()
        caps_plugin = self.xmpp.plugin["xep_0115"]

        if sender != self.xmpp.boundjid.bare:
            # Sender is not the gateway's bare JID: resolve the ver string
            # through the recipient's session and the matching contact.
            try:
                user_session = self.xmpp.get_session_from_jid(stanza.get_to())
            except XMPPError:
                log.debug("not adding caps 1")
                return stanza

            if user_session is None:
                return stanza

            await user_session.ready

            try:
                legacy_contact = await user_session.contacts.by_jid(sender)
            except XMPPError:
                return stanza
            else:
                verstring = await legacy_contact.get_caps_ver(sender)
        else:
            # Sender is the gateway's bare JID: ask the caps plugin directly.
            verstring = await caps_plugin.get_verstring(sender)

        log.debug("Ver: %s", verstring)

        if not verstring:
            return stanza

        caps_element = stanza["caps"]
        caps_element["node"] = caps_plugin.caps_node
        caps_element["hash"] = caps_plugin.hash
        caps_element["ver"] = verstring
        return stanza

65 

66 

# Module-level logger named after this module; CapsMixin uses it for
# its debug messages.
log = logging.getLogger(__name__)