Coverage for slidge/util/conf.py: 98%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-07 05:11 +0000

1import logging 

2from functools import cached_property 

3from types import GenericAlias 

4from typing import Optional, Union, get_args, get_origin, get_type_hints 

5 

6import configargparse 

7 

8 

class Option:
    """A single configuration option declared on a config object.

    An option ``NAME`` is an upper-case attribute and/or annotation of
    ``config_obj``. Optional companion attributes refine it:

    - ``NAME__DOC``: help text. Accessing :attr:`doc` raises
      ``AttributeError`` when it is missing.
    - ``NAME__DYNAMIC_DEFAULT``: its mere presence makes the option not
      required even without a static value.
    - ``NAME__SHORT``: short flag, registered as ``-<value>``.
    """

    DOC_SUFFIX = "__DOC"
    DYNAMIC_DEFAULT_SUFFIX = "__DYNAMIC_DEFAULT"
    SHORT_SUFFIX = "__SHORT"

    def __init__(self, parent: "ConfigModule", name: str):
        """
        :param parent: owning module; its ``config_obj`` and
            ``ENV_VAR_PREFIX`` are used by this option.
        :param name: upper-case attribute name of the option.
        """
        self.parent = parent
        self.config_obj = parent.config_obj
        self.name = name

    @cached_property
    def _declared_type(self):
        """Annotation of the option, or the type of its default when unannotated."""
        return get_type_hints(self.config_obj).get(self.name, type(self.default))

    @cached_property
    def doc(self):
        """Help text taken from ``NAME__DOC`` (mandatory)."""
        return getattr(self.config_obj, self.name + self.DOC_SUFFIX)

    @cached_property
    def required(self):
        """True unless the option has a static value or a ``NAME__DYNAMIC_DEFAULT``."""
        return not hasattr(
            self.config_obj, self.name + self.DYNAMIC_DEFAULT_SUFFIX
        ) and not hasattr(self.config_obj, self.name)

    @cached_property
    def default(self):
        """Static value declared for the option, or ``None`` if there is none."""
        return getattr(self.config_obj, self.name, None)

    @cached_property
    def short(self):
        """Value of ``NAME__SHORT``, or ``None``."""
        return getattr(self.config_obj, self.name + self.SHORT_SUFFIX, None)

    @cached_property
    def nargs(self):
        """``nargs`` to pass to argparse, or ``None`` for single-valued options.

        - ``"*"`` for variable-length generics: ``list[X]``, ``set[X]``,
          ``frozenset[X]`` and ``tuple[X, ...]``. argparse collects the
          values into a list.
        - The number of type parameters for other generics, e.g. ``2`` for
          ``tuple[X, Y]``.
        - ``None`` for non-generic types, including ``Optional[...]``.
        """
        type_ = self._declared_type
        if not isinstance(type_, GenericAlias):
            return None
        args = get_args(type_)
        # Single-parameter containers have no args[1]; indexing it used to
        # raise IndexError for e.g. ``list[str]``.
        if get_origin(type_) in (list, set, frozenset):
            return "*"
        if len(args) > 1 and args[1] is Ellipsis:
            return "*"
        return len(args)

    @cached_property
    def type(self):
        """Type used by argparse to convert each value.

        Returns ``X`` for ``Optional[X]`` and the first type parameter for
        generics such as ``list[X]`` or ``tuple[X, ...]``. Otherwise returns
        the declared type itself.
        """
        type_ = self._declared_type
        if _is_optional(type_):
            return get_args(type_)[0]
        if isinstance(type_, GenericAlias):
            return get_args(type_)[0]
        return type_

    @cached_property
    def names(self):
        """Flags for ``add_argument``: ``--lower-kebab-name`` plus ``-<short>`` if set."""
        res = ["--" + self.name.lower().replace("_", "-")]
        if s := self.short:
            res.append("-" + s)
        return res

    @cached_property
    def kwargs(self):
        """Keyword arguments for ``parser.add_argument``.

        Booleans become flags:

        - ``store_false`` when the default is truthy;
        - ``store_true`` otherwise.

        Every other type is passed as ``type``. Non-required options receive
        their ``default``. ``nargs`` is added only when :attr:`nargs` is
        truthy.
        """
        kwargs = dict(
            required=self.required,
            help=self.doc,
            env_var=self.name_to_env_var(),
        )
        t = self.type
        if t is bool:
            kwargs["action"] = "store_false" if self.default else "store_true"
        else:
            kwargs["type"] = t
        if not self.required:
            kwargs["default"] = self.default
        if n := self.nargs:
            kwargs["nargs"] = n
        return kwargs

    def name_to_env_var(self):
        """Environment variable name: the parent's ``ENV_VAR_PREFIX`` + option name."""
        return self.parent.ENV_VAR_PREFIX + self.name

92 

93 

class ConfigModule:
    """Expose the options declared on a config object through a ``configargparse`` parser.

    Options are discovered by :meth:`_list_options`, wrapped in
    :class:`Option`, and registered on the parser as soon as the instance is
    created. :meth:`set_conf` then parses arguments and writes the resulting
    values back onto ``config_obj`` as attributes.
    """

    # Prepended to each option name to build its environment variable name.
    ENV_VAR_PREFIX = "SLIDGE_"

    def __init__(
        self, config_obj, parser: Optional[configargparse.ArgumentParser] = None
    ):
        """
        :param config_obj: object declaring the options as upper-case
            attributes and/or annotations.
        :param parser: parser to register the options on. A new
            ``configargparse.ArgumentParser`` is created when omitted.
        """
        self.config_obj = config_obj
        if parser is None:
            parser = configargparse.ArgumentParser()
        self.parser = parser

        self.add_options_to_parser()

    def _list_options(self):
        """Return the set of option names declared on ``config_obj``.

        Option names must satisfy all of the following:

        - they are upper-case;
        - they do not start with ``_``;
        - they contain no ``__``, which excludes companions such as
          ``NAME__DOC``.

        Annotated names without a value are included too.
        """
        candidates = set(dir(self.config_obj)) | set(get_type_hints(self.config_obj))
        return {
            o
            for o in candidates
            if o.upper() == o and not o.startswith("_") and "__" not in o
        }

    def set_conf(self, argv: Optional[list[str]] = None):
        """Parse ``argv`` and store every option's parsed value on ``config_obj``.

        :param argv: arguments to parse. When given, explicit boolean values
            are first normalized by :meth:`_strip_explicit_bool_values`.
            When ``None``, it is passed unchanged to ``parse_known_args``.
        :return: the ``(namespace, remaining_args)`` pair returned by
            ``parse_known_args``.
        """
        if argv is not None:
            argv = self._strip_explicit_bool_values(argv)

        args, rest = self.parser.parse_known_args(argv)
        self.update_dynamic_defaults(args)
        for name in self._list_options():
            value = getattr(args, name.lower())
            log.debug("Setting '%s' to %r", name, value)
            setattr(self.config_obj, name, value)
        return args, rest

    def _strip_explicit_bool_values(self, argv: list[str]) -> list[str]:
        """Rewrite this module's boolean options in ``argv`` as bare flags, or drop them.

        This is ugly, but necessary because plugin config is parsed from the
        remaining argv. With .ini file(s), boolean options end up as
        pseudo-argv such as ``--some-bool-opt=true`` when we really should
        have just ``--some-bool-opt``.

        Rules, applied to this module's boolean options only:

        - ``--opt=value``, option defaults to false: the bare ``--opt`` is
          kept when ``value`` is true-ish; otherwise the argument is dropped.
        - ``--opt=value``, option defaults to true: the bare ``--opt`` (its
          ``store_false`` flag) is kept when ``value`` is neither true-ish
          nor empty; otherwise the argument is dropped.
        - Bare ``--opt``: kept only when the option defaults to false.
        - Bare ``--opt``, next argument: dropped unless it names one of this
          module's options (``--other`` or ``--other=value``).

        Every other argument is kept unchanged.

        TODO: get rid of configargparse and make this cleaner.
        """
        options_long = {o.name: o for o in self.options}
        no_explicit_bool = []
        skip_next = False
        for arg, next_arg in zip(argv, argv[1:] + [""]):
            if skip_next:
                skip_next = False
                continue
            force_keep = False
            if "=" in arg:
                # maxsplit=1: the value itself may legitimately contain "=";
                # a plain split() raised ValueError on e.g. "--opt=a=b".
                real_name, value = arg.split("=", 1)
                opt: Optional[Option] = options_long.get(
                    _argv_to_option_name(real_name)
                )
                if opt and opt.type is bool:
                    if opt.default:
                        if value in _TRUEISH or not value:
                            # Already true by default: nothing to pass.
                            continue
                        # Explicitly false: the bare (store_false) flag turns it off.
                        arg = real_name
                        force_keep = True
                    else:
                        if value in _TRUEISH:
                            # Explicitly true: the bare (store_true) flag turns it on.
                            arg = real_name
                            force_keep = True
                        else:
                            continue
            else:
                opt = options_long.get(_argv_to_option_name(arg))
                if opt and opt.type is bool:
                    # Compare only the name part, so a following "--known=value"
                    # is recognized as an option rather than discarded.
                    next_name = _argv_to_option_name(next_arg.split("=", 1)[0])
                    if next_name not in options_long:
                        log.debug("Removing %s from argv", next_arg)
                        skip_next = True

            # A bare flag for a default-true boolean is dropped unless it was
            # produced above from an explicit false value.
            if opt and opt.type is bool and opt.default and not force_keep:
                continue
            no_explicit_bool.append(arg)
        log.debug("Removed boolean values from %s to %s", argv, no_explicit_bool)
        return no_explicit_bool

    @cached_property
    def options(self) -> list[Option]:
        """One :class:`Option` per name returned by :meth:`_list_options`."""
        return [Option(self, name) for name in self._list_options()]

    def add_options_to_parser(self):
        """Register every option on the parser.

        Required options come first, then the rest; each group is sorted by
        name.
        """
        p = self.parser
        for o in sorted(self.options, key=lambda x: (not x.required, x.name)):
            p.add_argument(*o.names, **o.kwargs)

    def update_dynamic_defaults(self, args):
        """Hook called by :meth:`set_conf` with the parsed namespace.

        It runs before the values are copied onto ``config_obj``. It does
        nothing here; presumably subclasses override it to fill in dynamic
        defaults.
        """
        pass

189 

190 

def _is_optional(t):
    """Return whether ``t`` is ``Optional[X]``: a two-member union whose second member is ``None``.

    Recognizes ``typing.Optional[X]`` and ``Union[X, None]``. On Python
    3.10+ it also recognizes the PEP 604 form ``X | None``.

    Only the ``(X, None)`` order is accepted, because the wrapped type is
    then read as ``get_args(t)[0]``.
    """
    import types

    # types.UnionType only exists on Python 3.10+; fall back to Union alone.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(t) not in union_origins:
        return False
    args = get_args(t)
    # Identity check against NoneType: ``isinstance(None, arg)`` raises
    # TypeError when ``arg`` is a parameterized generic such as list[str].
    return len(args) == 2 and args[1] is type(None)

197 

198 

def _argv_to_option_name(arg: str):
    """Map a command-line token such as ``--some-opt`` to the option name ``SOME_OPT``.

    A leading ``--`` is removed, every remaining ``-`` becomes ``_``, and the
    result is upper-cased. Anything after ``=`` is kept as-is (upper-cased).
    """
    without_dashes = arg.removeprefix("--")
    return without_dashes.replace("-", "_").upper()

201 

202 

# Spellings recognized as "true" when a boolean option is given an explicit
# value in argv, e.g. ``--some-bool-opt=on``.
_TRUEISH = {"1", "on", "enabled", "true", "True"}


# Module logger.
log = logging.getLogger(name=__name__)