Coverage for slidge/util/conf.py: 97%

146 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-04 08:17 +0000

1import logging 

2from functools import cached_property 

3from types import GenericAlias 

4from typing import Any, Optional, Union, cast, get_args, get_origin, get_type_hints 

5 

6import configargparse 

7 

8 

9class Option: 

10 DOC_SUFFIX = "__DOC" 

11 DYNAMIC_DEFAULT_SUFFIX = "__DYNAMIC_DEFAULT" 

12 SHORT_SUFFIX = "__SHORT" 

13 

14 def __init__(self, parent: "ConfigModule", name: str) -> None: 

15 self.parent = parent 

16 self.config_obj = parent.config_obj 

17 self.name = name 

18 

19 @cached_property 

20 def doc(self) -> str: 

21 return getattr(self.config_obj, self.name + self.DOC_SUFFIX) # type:ignore 

22 

23 @cached_property 

24 def required(self) -> bool: 

25 return not hasattr( 

26 self.config_obj, self.name + self.DYNAMIC_DEFAULT_SUFFIX 

27 ) and not hasattr(self.config_obj, self.name) 

28 

29 @cached_property 

30 def default(self) -> Any: 

31 return getattr(self.config_obj, self.name, None) 

32 

33 @cached_property 

34 def short(self) -> str | None: 

35 return getattr(self.config_obj, self.name + self.SHORT_SUFFIX, None) 

36 

37 @cached_property 

38 def nargs(self) -> str | int | None: 

39 type_ = get_type_hints(self.config_obj).get(self.name, type(self.default)) 

40 

41 if isinstance(type_, GenericAlias): 

42 args = get_args(type_) 

43 if args[1] is Ellipsis: 

44 return "*" 

45 else: 

46 return len(args) 

47 return None 

48 

49 @cached_property 

50 def type(self) -> Any: 

51 type_ = get_type_hints(self.config_obj).get(self.name, type(self.default)) 

52 

53 if _is_optional(type_): 

54 type_ = get_args(type_)[0] 

55 elif isinstance(type_, GenericAlias): 

56 args = get_args(type_) 

57 type_ = args[0] 

58 

59 return type_ 

60 

61 @cached_property 

62 def names(self) -> list[str]: 

63 res = ["--" + self.name.lower().replace("_", "-")] 

64 if s := self.short: 

65 res.append("-" + s) 

66 return res 

67 

68 @cached_property 

69 def kwargs(self) -> dict[str, Any]: 

70 kwargs = dict( 

71 required=self.required, 

72 help=self.doc, 

73 env_var=self.name_to_env_var(), 

74 ) 

75 t = self.type 

76 if t is bool: 

77 if self.default: 

78 kwargs["action"] = "store_false" 

79 else: 

80 kwargs["action"] = "store_true" 

81 else: 

82 kwargs["type"] = t 

83 if self.required: 

84 kwargs["required"] = True 

85 else: 

86 kwargs["default"] = self.default 

87 if n := self.nargs: 

88 kwargs["nargs"] = n 

89 return kwargs 

90 

91 def name_to_env_var(self) -> str: 

92 return self.parent.ENV_VAR_PREFIX + self.name 

93 

94 

class ConfigModule:
    """Bind the options of a config object to a configargparse parser.

    Options are the public, all-uppercase attributes and annotated names of
    ``config_obj``. They are registered on the parser at construction time;
    ``set_conf`` parses argv and writes the parsed values back onto
    ``config_obj``.
    """

    # Prefix of the environment variable associated with each option.
    ENV_VAR_PREFIX = "SLIDGE_"

    def __init__(
        self,
        config_obj: Any,
        parser: Optional[configargparse.ArgumentParser] = None,
        skip_options: tuple[str, ...] = (),
    ) -> None:
        """
        :param config_obj: object whose uppercase attributes are the options.
        :param parser: parser to register options on; a new
            ``configargparse.ArgumentParser`` is created when ``None``.
        :param skip_options: option names (any case) to leave out entirely.
        """
        self.config_obj = config_obj
        if parser is None:
            parser = configargparse.ArgumentParser()
        self.parser = parser

        # Normalise once so that _list_options (and thus set_conf) agrees with
        # add_options_to_parser on which options are skipped, whatever the
        # caller's casing.
        self.skip_options = tuple(o.lower() for o in skip_options)
        self.add_options_to_parser(skip_options)

    def _list_options(self) -> set[str]:
        """Names of the options defined on ``config_obj``.

        Candidates are attributes and type-annotated names. A candidate is kept
        if it is all-uppercase, does not start with ``_``, contains no ``__``
        (which excludes ``NAME__DOC``-style metadata), and is not skipped.
        """
        candidates = set(dir(self.config_obj)) | set(get_type_hints(self.config_obj))
        return {
            o
            for o in candidates
            if o.upper() == o
            and not o.startswith("_")
            and "__" not in o
            and o.lower() not in self.skip_options
        }

    def set_conf(
        self, argv: Optional[list[str]] = None
    ) -> tuple[configargparse.Namespace, list[str]]:
        """Parse arguments and copy every option value onto ``config_obj``.

        :param argv: arguments to parse. When not ``None``, boolean options are
            first normalised (see ``_normalize_bool_argv``). ``None`` is passed
            to ``parse_known_args`` unchanged.
        :return: the parsed namespace and the unrecognised arguments.
        """
        if argv is not None:
            argv = self._normalize_bool_argv(argv)

        args, rest = self.parser.parse_known_args(argv)
        self.update_dynamic_defaults(args)
        for name in self._list_options():
            value = getattr(args, name.lower())
            log.debug("Setting '%s' to %r", name, value)
            setattr(self.config_obj, name, value)
        return args, rest

    def _normalize_bool_argv(self, argv: list[str]) -> list[str]:
        """Rewrite argv so boolean options appear only as bare flags.

        This is ugly, but necessary because for plugin config, we use the
        remaining argv. When using (a) .ini file(s), for bool options, we end up
        with remaining pseudo-argv such as ``--some-bool-opt=true`` when we
        really should have just ``--some-bool-opt``.
        TODO: get rid of configargparse and make this cleaner.
        """
        options_long = {o.name: o for o in self.options}
        result: list[str] = []
        skip_next = False
        for arg, next_arg in zip(argv, argv[1:] + [""]):
            if skip_next:
                skip_next = False
                continue

            # partition (not split) so values containing "=" such as
            # --foo=a=b do not raise ValueError on unpacking.
            flag, has_value, value = arg.partition("=")
            opt: Optional[Option] = options_long.get(_argv_to_option_name(flag))
            if opt is None or opt.type is not bool:
                result.append(arg)
                continue

            if has_value:
                # Explicit value: emit the bare flag only when it asks for the
                # non-default value. An empty value keeps the default.
                if opt.default:
                    keep = bool(value) and value not in _TRUEISH
                else:
                    keep = value in _TRUEISH
                if keep:
                    result.append(flag)
                continue

            # Bare flag: a following token that is not itself an option is
            # taken to be a value for this flag and dropped. Option tokens,
            # including "--other=value" forms and options belonging to another
            # module, are kept.
            if next_arg and not next_arg.startswith("-"):
                log.debug("Removing %s from argv", next_arg)
                skip_next = True
            # A truthy-default bool uses store_false, so a bare flag (meaning
            # "true") must not be forwarded.
            if not opt.default:
                result.append(arg)

        log.debug("Removed boolean values from %s to %s", argv, result)
        return result

    @cached_property
    def options(self) -> list[Option]:
        """One ``Option`` per name returned by ``_list_options``."""
        return [Option(self, name) for name in self._list_options()]

    def add_options_to_parser(self, skip_options: tuple[str, ...]) -> None:
        """Register options on the parser: required ones first, then by name."""
        skip_options = tuple(o.lower() for o in skip_options)
        p = self.parser
        for o in sorted(self.options, key=lambda x: (not x.required, x.name)):
            if o.name.lower() in skip_options:
                continue
            p.add_argument(*o.names, **o.kwargs)

    def update_dynamic_defaults(self, args: configargparse.Namespace) -> None:
        """Hook called by ``set_conf`` right after parsing; no-op here.

        Presumably overridden by subclasses to fill in dynamic defaults.
        """
        pass

202 

203 

204def _is_optional(t: Any) -> bool: 

205 if get_origin(t) is Union: 

206 args = get_args(t) 

207 if len(args) == 2 and isinstance(None, args[1]): 

208 return True 

209 return False 

210 

211 

def _argv_to_option_name(arg: str) -> str:
    """Map a long command-line flag to its config attribute name.

    ``"--some-opt"`` becomes ``"SOME_OPT"``. Only a leading ``--`` is removed;
    every remaining dash becomes an underscore.
    """
    upper_cased = arg.upper()
    bare = upper_cased.removeprefix("--")
    return bare.replace("-", "_")

214 

215 

# Strings accepted as boolean "true" in ``--option=value`` argv items.
# Matching is plain set membership, hence case-sensitive.
_TRUEISH: set[str] = {
    "true",
    "True",
    "1",
    "on",
    "enabled",
}


log: logging.Logger = logging.getLogger(__name__)