Coverage for slidge/util/conf.py: 97%

146 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-26 19:34 +0000

1import logging 

2from functools import cached_property 

3from types import GenericAlias 

4from typing import Any, Optional, Union, cast, get_args, get_origin, get_type_hints 

5 

6import configargparse 

7 

8 

9class Option: 

10 DOC_SUFFIX = "__DOC" 

11 DYNAMIC_DEFAULT_SUFFIX = "__DYNAMIC_DEFAULT" 

12 SHORT_SUFFIX = "__SHORT" 

13 

14 def __init__(self, parent: "ConfigModule", name: str) -> None: 

15 self.parent = parent 

16 self.config_obj = parent.config_obj 

17 self.name = name 

18 

19 @cached_property 

20 def doc(self) -> str: 

21 return getattr(self.config_obj, self.name + self.DOC_SUFFIX) # type:ignore 

22 

23 @cached_property 

24 def required(self) -> bool: 

25 return not hasattr( 

26 self.config_obj, self.name + self.DYNAMIC_DEFAULT_SUFFIX 

27 ) and not hasattr(self.config_obj, self.name) 

28 

29 @cached_property 

30 def default(self) -> Any: 

31 return getattr(self.config_obj, self.name, None) 

32 

33 @cached_property 

34 def short(self) -> str | None: 

35 return getattr(self.config_obj, self.name + self.SHORT_SUFFIX, None) 

36 

37 @cached_property 

38 def nargs(self) -> str | int | None: 

39 type_ = get_type_hints(self.config_obj).get(self.name, type(self.default)) 

40 

41 if isinstance(type_, GenericAlias): 

42 args = get_args(type_) 

43 if args[1] is Ellipsis: 

44 return "*" 

45 else: 

46 return len(args) 

47 return None 

48 

49 @cached_property 

50 def type(self) -> Any: 

51 type_ = get_type_hints(self.config_obj).get(self.name, type(self.default)) 

52 

53 if _is_optional(type_): 

54 type_ = get_args(type_)[0] 

55 elif isinstance(type_, GenericAlias): 

56 args = get_args(type_) 

57 type_ = args[0] 

58 

59 return type_ 

60 

61 @cached_property 

62 def names(self) -> list[str]: 

63 res = ["--" + self.name.lower().replace("_", "-")] 

64 if s := self.short: 

65 res.append("-" + s) 

66 return res 

67 

68 @cached_property 

69 def kwargs(self) -> dict[str, Any]: 

70 kwargs = dict( 

71 required=self.required, 

72 help=self.doc, 

73 env_var=self.name_to_env_var(), 

74 ) 

75 t = self.type 

76 if t is bool: 

77 if self.default: 

78 kwargs["action"] = "store_false" 

79 else: 

80 kwargs["action"] = "store_true" 

81 else: 

82 kwargs["type"] = t 

83 if self.required: 

84 kwargs["required"] = True 

85 else: 

86 kwargs["default"] = self.default 

87 if n := self.nargs: 

88 kwargs["nargs"] = n 

89 return kwargs 

90 

91 def name_to_env_var(self) -> str: 

92 return self.parent.ENV_VAR_PREFIX + self.name 

93 

94 

class ConfigModule:
    """Expose the options declared on a config object through a parser.

    Every public upper-case attribute or annotation of ``config_obj`` becomes
    an :class:`Option` registered on ``parser``. :meth:`set_conf` parses
    arguments and writes each option's parsed value back onto ``config_obj``.
    """

    ENV_VAR_PREFIX = "SLIDGE_"

    def __init__(
        self,
        config_obj: Any,
        parser: Optional[configargparse.ArgumentParser] = None,
        skip_options: tuple[str, ...] = (),
    ) -> None:
        """
        :param config_obj: object whose upper-case attributes/annotations
            declare the options.
        :param parser: parser to register options on; a new
            ``configargparse.ArgumentParser`` is created when ``None``.
        :param skip_options: option names (case-insensitive) not to expose.
        """
        self.config_obj = config_obj
        if parser is None:
            parser = configargparse.ArgumentParser()
        self.parser = parser

        self.skip_options = skip_options
        self.add_options_to_parser(skip_options)

    def _list_options(self) -> set[str]:
        """Return the option names declared on ``config_obj``.

        A name qualifies when it is entirely upper case, does not start with
        ``_``, contains no ``__`` (which excludes the ``__DOC``/``__SHORT``/...
        companion attributes) and is not in ``skip_options``.
        """
        # Compare case-insensitively, exactly like add_options_to_parser():
        # otherwise an upper-case skip entry would be listed here but never
        # registered on the parser, and set_conf() would fail looking it up.
        skipped = {s.lower() for s in self.skip_options}
        candidates = set(dir(self.config_obj)) | set(get_type_hints(self.config_obj))
        return {
            o
            for o in candidates
            if o.upper() == o
            and not o.startswith("_")
            and "__" not in o
            and o.lower() not in skipped
        }

    def _strip_bool_values(self, argv: list[str]) -> list[str]:
        """Turn explicit boolean values in *argv* into bare flags, or drop them.

        This is ugly, but necessary because plugin config uses the remaining
        argv: when using .ini file(s), boolean options end up as pseudo-argv
        such as ``--some-bool-opt=true`` when argparse expects just
        ``--some-bool-opt``.
        TODO: get rid of configargparse and make this cleaner.
        """
        options_long = {o.name: o for o in self.options}
        no_explicit_bool = []
        skip_next = False
        # Pair every token with the next one ("" after the last token).
        for a, aa in zip(argv, argv[1:] + [""]):
            if skip_next:
                skip_next = False
                continue
            force_keep = False
            if "=" in a:
                # Split on the first "=" only: the value may itself contain "=".
                real_name, _value = a.split("=", 1)
                opt: Optional[Option] = options_long.get(
                    _argv_to_option_name(real_name)
                )
                if opt and opt.type is bool:
                    if opt.default:
                        if _value in _TRUEISH or not _value:
                            continue
                        else:
                            a = real_name
                            force_keep = True
                    else:
                        if _value in _TRUEISH:
                            a = real_name
                            force_keep = True
                        else:
                            continue
            else:
                opt = options_long.get(_argv_to_option_name(a))
                if opt and opt.type is bool:
                    # A following token that is neither a flag nor an option
                    # name is treated as this flag's value and discarded.
                    if (
                        not aa.startswith("-")
                        and _argv_to_option_name(aa) not in options_long
                    ):
                        log.debug("Removing %s from argv", aa)
                        skip_next = True

            # A default-True flag that was not explicitly switched off is a
            # no-op, so it is dropped; every other token is kept.
            if opt and opt.type is bool and opt.default and not force_keep:
                continue
            no_explicit_bool.append(a)
        log.debug("Removed boolean values from %s to %s", argv, no_explicit_bool)
        return no_explicit_bool

    def set_conf(
        self, argv: Optional[list[str]] = None
    ) -> tuple[configargparse.Namespace, list[str]]:
        """Parse arguments and copy every option's value onto ``config_obj``.

        :param argv: arguments to parse. When given, explicit boolean values
            are normalised first; ``None`` is passed through unchanged to
            ``parser.parse_known_args``.
        :return: the parsed namespace and the unrecognised arguments.
        """
        if argv is not None:
            argv = self._strip_bool_values(argv)

        args, rest = self.parser.parse_known_args(argv)
        self.update_dynamic_defaults(args)
        for name in self._list_options():
            value = getattr(args, name.lower())
            log.debug("Setting '%s' to %r", name, value)
            setattr(self.config_obj, name, value)
        return args, rest

    @cached_property
    def options(self) -> list[Option]:
        """One :class:`Option` per name returned by :meth:`_list_options`."""
        return [Option(self, opt) for opt in self._list_options()]

    def add_options_to_parser(self, skip_options: tuple[str, ...]) -> None:
        """Register options on ``self.parser``, required ones first.

        Options are sorted by (not required, name). Options whose lower-cased
        name is in *skip_options* (compared case-insensitively) are skipped.
        """
        skip_options = tuple(o.lower() for o in skip_options)
        p = self.parser
        for o in sorted(self.options, key=lambda x: (not x.required, x.name)):
            if o.name.lower() in skip_options:
                continue
            p.add_argument(*o.names, **o.kwargs)

    def update_dynamic_defaults(self, args: configargparse.Namespace) -> None:
        """Hook called by :meth:`set_conf` with the parsed namespace.

        It is called before values are copied onto ``config_obj``. This
        implementation does nothing.
        """
        pass

205 

206 

207def _is_optional(t: Any) -> bool: 

208 if get_origin(t) is Union: 

209 args = get_args(t) 

210 if len(args) == 2 and isinstance(None, args[1]): 

211 return True 

212 return False 

213 

214 

def _argv_to_option_name(arg: str) -> str:
    """Map a command-line token such as ``--some-opt`` to ``SOME_OPT``.

    The token is upper-cased, one leading ``--`` is stripped, and every
    remaining ``-`` becomes ``_``.
    """
    name = arg.upper()
    if name.startswith("--"):
        name = name[len("--"):]
    return "_".join(name.split("-"))

217 

218 

# Logger for this module.
log = logging.getLogger(__name__)

# Values treated as "true" when a boolean option is given as ``--opt=value``.
_TRUEISH = frozenset({"true", "True", "1", "on", "enabled"})