Coverage for slidge/util/conf.py: 98%
141 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-07 05:11 +0000
1import logging
2from functools import cached_property
3from types import GenericAlias
4from typing import Optional, Union, get_args, get_origin, get_type_hints
6import configargparse
class Option:
    """One configuration option, described by attributes of a config object.

    For an option called ``NAME`` the following attributes of
    ``parent.config_obj`` are read:

    - ``NAME``: the default value, if present;
    - ``NAME__DOC``: the help text (must exist, see :attr:`doc`);
    - ``NAME__SHORT``: an optional short flag;
    - ``NAME__DYNAMIC_DEFAULT``: its mere presence makes the option
      non-required;
    - the type hint of ``NAME``, used to derive :attr:`type` and :attr:`nargs`.

    All derived values are cached on first access.
    """

    DOC_SUFFIX = "__DOC"
    DYNAMIC_DEFAULT_SUFFIX = "__DYNAMIC_DEFAULT"
    SHORT_SUFFIX = "__SHORT"

    def __init__(self, parent: "ConfigModule", name: str):
        """Bind the option called *name* to *parent* and its config object."""
        self.parent = parent
        self.config_obj = parent.config_obj
        self.name = name

    @cached_property
    def doc(self):
        """Help text; raises AttributeError if ``NAME__DOC`` is missing."""
        return getattr(self.config_obj, self.name + self.DOC_SUFFIX)

    @cached_property
    def required(self):
        """True when neither a default nor a dynamic-default marker exists."""
        return not hasattr(
            self.config_obj, self.name + self.DYNAMIC_DEFAULT_SUFFIX
        ) and not hasattr(self.config_obj, self.name)

    @cached_property
    def default(self):
        """Current value of ``NAME`` on the config object, or None."""
        return getattr(self.config_obj, self.name, None)

    @cached_property
    def short(self):
        """Value of ``NAME__SHORT`` on the config object, or None."""
        return getattr(self.config_obj, self.name + self.SHORT_SUFFIX, None)

    @cached_property
    def nargs(self):
        """Number of values the option takes, derived from its type hint.

        Returns None for anything that is not a parameterised builtin
        generic, ``"*"`` for ``tuple[X, ...]`` and for single-parameter
        non-tuple collections such as ``list[X]``, and the number of type
        parameters otherwise (e.g. 2 for ``tuple[int, int]``).
        """
        type_ = get_type_hints(self.config_obj).get(self.name, type(self.default))

        if not isinstance(type_, GenericAlias):
            return None

        args = get_args(type_)
        if len(args) < 2:
            # Aliases with fewer than two parameters (``list[str]``,
            # ``tuple[int]``, ``tuple[()]``) have no ``args[1]`` to inspect;
            # indexing it unconditionally raised IndexError.
            if get_origin(type_) is tuple:
                return len(args)
            return "*"
        if args[1] is Ellipsis:
            return "*"
        return len(args)

    @cached_property
    def type(self):
        """Element type of the option.

        ``Optional[X]`` and parameterised generics are both reduced to their
        first type parameter; any other hint is returned unchanged. Without a
        hint, the type of the default value is used.
        """
        type_ = get_type_hints(self.config_obj).get(self.name, type(self.default))

        if _is_optional(type_):
            type_ = get_args(type_)[0]
        elif isinstance(type_, GenericAlias):
            args = get_args(type_)
            type_ = args[0]

        return type_

    @cached_property
    def names(self):
        """Command-line flags: ``--lower-case-name`` plus ``-short`` if set."""
        res = ["--" + self.name.lower().replace("_", "-")]
        if s := self.short:
            res.append("-" + s)
        return res

    @cached_property
    def kwargs(self):
        """Keyword arguments passed to the parser's ``add_argument``."""
        kwargs = dict(
            required=self.required,
            help=self.doc,
            env_var=self.name_to_env_var(),
        )
        t = self.type
        if t is bool:
            # A flag flips the default: passing it stores the opposite value.
            kwargs["action"] = "store_false" if self.default else "store_true"
        else:
            kwargs["type"] = t
        # ``required`` is already set above; only the default is conditional.
        if not self.required:
            kwargs["default"] = self.default
        if n := self.nargs:
            kwargs["nargs"] = n
        return kwargs

    def name_to_env_var(self):
        """Return the environment variable name: parent prefix + option name."""
        return self.parent.ENV_VAR_PREFIX + self.name
class ConfigModule:
    """Expose the upper-case options of a config object through a parser.

    On construction every option found on ``config_obj`` is registered on the
    parser; :meth:`set_conf` then parses arguments and writes the resulting
    values back onto ``config_obj``.
    """

    ENV_VAR_PREFIX = "SLIDGE_"

    def __init__(
        self, config_obj, parser: Optional[configargparse.ArgumentParser] = None
    ):
        """Store *config_obj* and register its options on *parser*.

        A new ``configargparse.ArgumentParser`` is created when *parser* is
        None.
        """
        self.config_obj = config_obj
        if parser is None:
            parser = configargparse.ArgumentParser()
        self.parser = parser

        self.add_options_to_parser()

    def _list_options(self):
        """Return the set of option names found on the config object.

        A name qualifies when it is all upper-case, does not start with an
        underscore and contains no double underscore (which excludes the
        ``__DOC``/``__SHORT``/``__DYNAMIC_DEFAULT`` companions). Names come
        from both ``dir()`` and the type hints, so annotated options without
        a value are included.
        """
        candidates = set(dir(self.config_obj)) | set(get_type_hints(self.config_obj))
        return {
            o
            for o in candidates
            if o.upper() == o and not o.startswith("_") and "__" not in o
        }

    def _strip_explicit_bool_values(self, argv):
        """Return a copy of *argv* with explicit values of bool options removed.

        ``--flag=<value>`` and ``--flag <value>`` forms of boolean options are
        rewritten to a bare ``--flag`` or dropped, because boolean options are
        registered as ``store_true``/``store_false`` flags that take no value.
        """
        options_long = {o.name: o for o in self.options}
        no_explicit_bool = []
        skip_next = False
        # Pair each argument with its successor; "" pads the last one.
        for a, aa in zip(argv, argv[1:] + [""]):
            if skip_next:
                skip_next = False
                continue
            force_keep = False
            if "=" in a:
                # Split on the first "=" only: values may contain "=" too, and
                # a plain split("=") made the two-target unpacking raise
                # ValueError for them.
                real_name, _, _value = a.partition("=")
                opt: Optional[Option] = options_long.get(
                    _argv_to_option_name(real_name)
                )
                if opt and opt.type is bool:
                    if opt.default:
                        # Default is True: a true-ish or empty value means
                        # "leave the default", so the argument is dropped.
                        if _value in _TRUEISH or not _value:
                            continue
                        a = real_name
                        force_keep = True
                    else:
                        # Default is False: only a true-ish value sets the flag.
                        if _value not in _TRUEISH:
                            continue
                        a = real_name
                        force_keep = True
            else:
                upper = _argv_to_option_name(a)
                opt = options_long.get(upper)
                if opt and opt.type is bool:
                    # The next token is treated as this flag's explicit value
                    # unless it names one of our own options.
                    if _argv_to_option_name(aa) not in options_long:
                        log.debug("Removing %s from argv", aa)
                        skip_next = True

            # NOTE(review): a bare flag for a bool option whose default is
            # true-ish is dropped here unless it came from the "=" form above
            # -- confirm this is intended.
            if opt and opt.type is bool and opt.default and not force_keep:
                continue
            no_explicit_bool.append(a)
        log.debug("Removed boolean values from %s to %s", argv, no_explicit_bool)
        return no_explicit_bool

    def set_conf(self, argv: Optional[list[str]] = None):
        """Parse *argv* and copy the parsed values onto the config object.

        :param argv: arguments to parse; when None it is passed through to the
            parser's ``parse_known_args`` unchanged.
        :return: the ``(args, rest)`` pair returned by ``parse_known_args``.
        """
        if argv is not None:
            # this is ugly, but necessary because for plugin config, we used
            # remaining argv.
            # when using (a) .ini file(s), for bool options, we end-up with
            # remaining pseudo-argv such as --some-bool-opt=true when we really
            # should have just --some-bool-opt
            # TODO: get rid of configargparse and make this cleaner
            argv = self._strip_explicit_bool_values(argv)

        args, rest = self.parser.parse_known_args(argv)
        self.update_dynamic_defaults(args)
        for name in self._list_options():
            value = getattr(args, name.lower())
            log.debug("Setting '%s' to %r", name, value)
            setattr(self.config_obj, name, value)
        return args, rest

    @cached_property
    def options(self) -> list[Option]:
        """One :class:`Option` per name from :meth:`_list_options` (cached)."""
        return [Option(self, opt) for opt in self._list_options()]

    def add_options_to_parser(self):
        """Register every option on the parser, required ones first, by name."""
        p = self.parser
        for o in sorted(self.options, key=lambda x: (not x.required, x.name)):
            p.add_argument(*o.names, **o.kwargs)

    def update_dynamic_defaults(self, args):
        """Hook called by :meth:`set_conf` with the parsed args; no-op here."""
        pass
191def _is_optional(t):
192 if get_origin(t) is Union:
193 args = get_args(t)
194 if len(args) == 2 and isinstance(None, args[1]):
195 return True
196 return False
199def _argv_to_option_name(arg: str):
200 return arg.upper().removeprefix("--").replace("-", "_")
203_TRUEISH = {"true", "True", "1", "on", "enabled"}
206log = logging.getLogger(__name__)