Last active
January 26, 2024 17:13
-
-
Save mikel-zhobro/0a75a2b6bb937da7128a2e0a24950734 to your computer and use it in GitHub Desktop.
One file dataclass argparse
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from argparse import ArgumentParser
from dataclasses import dataclass, fields, Field, field
from typing import Any
class col:
    """ANSI SGR escape sequences for coloring and styling terminal text.

    Concatenate one of these before the text and ``col.ENDC`` after it to
    return the terminal to its default style.
    """

    # Colors
    HEADER = "\x1b[95m"
    OKBLUE = "\x1b[94m"
    OKCYAN = "\x1b[96m"
    OKGREEN = "\x1b[92m"
    OKYELLOW = "\x1b[33m"
    WARNING = "\x1b[93m"
    FAIL = "\x1b[91m"

    # Reset and text attributes
    ENDC = "\x1b[0m"
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"
class DotDict(dict):
    """A ``dict`` whose items can also be read and written as attributes.

    * ``d.x = v`` stores the item ``d["x"]`` (``__setattr__`` is
      ``dict.__setitem__``), and ``del d.x`` deletes it.
    * ``d.x`` returns ``d["x"]`` whenever that key exists, even if the class
      also defines an attribute ``x``. Otherwise normal attribute lookup
      runs, and a name that is still missing yields ``None`` instead of
      raising ``AttributeError``.

    Meant to be subclassed by ``@dataclass`` classes. Their generated
    ``__init__`` assigns every field through ``__setattr__``, so each field
    is stored as a dict item. ``add_to_parser`` exposes the fields as
    command-line options, and ``update`` copies parsed values back in.
    """

    # Missing attributes fall back to dict.get -> None (no AttributeError).
    __getattr__ = dict.get
    # Attribute writes and deletes act directly on the dict items.
    __setattr__ = dict.__setitem__  # type: ignore
    __delattr__ = dict.__delitem__  # type: ignore

    def __init__(self, name="DICT", print_changes=True):
        """Create an empty mapping.

        Args:
            name: Title used for the argparse group and the update banner.
            print_changes: When truthy, ``update`` prints a table of the
                values it changes.
        """
        super().__init__()
        # Both end up as dict items, because __setattr__ is dict.__setitem__.
        self._name = name
        self._print_changes = print_changes

    def __getattribute__(self, name: str) -> Any:
        # Dict items take precedence over class attributes. Without this, a
        # dataclass field with a plain default (e.g. ``source_path = "data"``)
        # keeps resolving to the class-level default even after its item has
        # been updated, because normal lookup finds the class attribute first.
        if name in super().keys():
            return super().get(name, None)
        return super().__getattribute__(name)

    def __getstate__(self):
        # The items themselves are pickled by dict's own machinery. This
        # only round-trips the instance __dict__, which stays empty unless
        # something bypasses __setattr__ (e.g. object.__setattr__).
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def update(self, **kwarg):
        """Overwrite existing items whose value differs from ``kwarg``.

        Unlike ``dict.update``, this never adds keys. Entries of ``kwarg``
        whose key is not already present are ignored, so one flat
        ``vars(args)`` namespace can be applied to several parameter groups.

        If the ``_print_changes`` item is truthy or absent, a colored table
        of the changed keys is printed before they are written. It is absent
        for dataclass subclasses, which never call ``DotDict.__init__``.

        Returns:
            ``self``, so calls can be chained.
        """
        changed = [key for key in kwarg if key in self and self[key] != kwarg[key]]

        # Honor the print_changes flag from __init__ (previously ignored).
        # Use .get() so instances that never set the flag keep printing.
        if changed and self.get("_print_changes", True):
            print(
                f"\n{col.UNDERLINE}Updating: {self.get('_name', 'X')}-args {col.ENDC}\n"
            )
            print(
                f'{"key": ^15}: {col.FAIL} {"default": ^15} -> {col.OKGREEN} {"new":^15} {col.ENDC}'
            )
            print(
                f'{"-" * 15: ^15} {col.FAIL} {"-" * 15: ^15} {col.OKGREEN} {"-" * 15:^15} {col.ENDC}'
            )
            for key in changed:
                print(
                    f"{key: ^15}: {col.FAIL} {str(self[key]): ^15} -> {col.OKGREEN} {str(kwarg[key]):^15} {col.ENDC}"
                )

        # The important part: write the new values (as dict items).
        for key in changed:
            setattr(self, key, kwarg[key])
        return self

    def add_to_parser(self, parser: ArgumentParser):
        """Register every public dataclass field as a ``--<field>`` option.

        The options go into an argument group titled ``self._name``. Each
        option's default is the field's current value on this instance, so a
        value passed to the constructor is not reset by a later ``update``.
        If the instance has no such item, the field's declared default is
        used. Fields whose name starts with ``_``, and fields with neither a
        value nor a default, are skipped.

        * List values take one or more arguments (``nargs='+'``). The element
          type comes from the first element or, for an empty list, from the
          ``list[...]`` annotation (``str`` when it cannot be resolved).
        * Bool values accept a bare ``--flag`` (sets True) or
          ``--flag <true|false|yes|no|1|0|...>``.
        * Any other value takes exactly one argument, converted with
          ``type(value)`` (``str`` when the value is ``None``).

        A non-dataclass instance registers nothing.

        Args:
            parser: Parser to add the option group to.

        Returns:
            ``self``, so the call can be chained onto the constructor.
        """
        from argparse import ArgumentTypeError
        from dataclasses import MISSING, is_dataclass

        builtin_types = {"int": int, "float": float, "str": str, "bool": bool}

        def parse_bool(text):
            # argparse's type=bool would turn any non-empty string into True.
            lowered = text.strip().lower()
            if lowered in ("1", "true", "t", "yes", "y", "on"):
                return True
            if lowered in ("0", "false", "f", "no", "n", "off"):
                return False
            raise ArgumentTypeError(f"expected a boolean, got {text!r}")

        def element_type(annotation):
            # Resolve the item type of 'list[int]' (string annotation) or
            # list[int] / List[int] (real type) without eval().
            if isinstance(annotation, str):
                inner = annotation.partition("[")[2].rpartition("]")[0].strip()
                return builtin_types.get(inner, str)
            args = getattr(annotation, "__args__", None) or (str,)
            inner = args[0]
            if isinstance(inner, str):
                return builtin_types.get(inner, str)
            return inner if isinstance(inner, type) else str

        if not is_dataclass(self):
            return self

        group = parser.add_argument_group(self._name)
        for fld in fields(type(self)):
            if fld.name.startswith("_"):
                continue

            # Prefer the live value so constructor arguments are respected.
            if fld.name in self:
                value = self[fld.name]
            elif fld.default_factory is not MISSING:
                value = fld.default_factory()
            else:
                value = fld.default
            if value is MISSING:
                continue  # nothing to infer a type or a default from

            option = "--" + fld.name
            if isinstance(value, list):
                elem = type(value[0]) if value else element_type(fld.type)
                if elem is bool:
                    elem = parse_bool
                group.add_argument(option, default=value, type=elem, nargs="+")
            elif isinstance(value, bool):
                group.add_argument(
                    option, default=value, type=parse_bool, nargs="?", const=True
                )
            else:
                # Exactly one argument: the old nargs='?' let a bare
                # ``--option`` silently set the field to None.
                scalar_type = str if value is None else type(value)
                group.add_argument(option, default=value, type=scalar_type)
        return self
@dataclass
class GeneralParams(DotDict):
    """General run options, registered under the ``GENERAL`` option group.

    Every field is a list, so each becomes a ``--<field>`` option that takes
    one or more values.
    """

    # Iteration numbers, like the two fields below. The default list is empty,
    # so add_to_parser takes the element type from this annotation. With
    # ``list[str]`` the values were parsed as strings.
    checkpoint_iterations: 'list[int]' = field(default_factory=list)
    test_iterations: 'list[int]' = field(default_factory=lambda: [7_000, 30_000])
    save_iterations: 'list[int]' = field(default_factory=lambda: [7_000, 30_000])

    def __post_init__(self):
        # Title of this group's argparse section and of update()'s banner.
        self._name = "GENERAL"
@dataclass
class ModelParams(DotDict):
    """Model path options, registered under the ``MODEL`` option group.

    Both fields are strings, so each becomes a ``--<field>`` option that
    takes a single value.
    """

    source_path: str = "data"
    model_path: str = ""

    def __post_init__(self):
        # Title of this group's argparse section and of update()'s banner.
        self._name = "MODEL"
@dataclass
class OptimizationParams(DotDict):
    """Optimization settings, registered under the ``OPTIM`` option group.

    Each field becomes a ``--<field>`` option converted with the type of its
    value.
    """

    # Annotated ``int`` to match its default (200). The option is converted
    # with the value's type, so ``--iterations`` already expected an integer;
    # the old ``float`` annotation was misleading.
    iterations: int = 200
    random_background: bool = False
    lambda_loss: float = 0.2
    lr_init: float = 0.00016
    lr_final: float = 0.0000016
    lr_delay_mult: float = 0.01

    def __post_init__(self):
        # Title of this group's argparse section and of update()'s banner.
        self._name = "OPTIM"
# Build a single CLI with one option group per parameter set, parse sys.argv,
# then push the parsed values back into every set. update() only touches keys
# a set already owns, so all sets can share the one flat namespace.
# NOTE(review): this runs at import time. Consider an
# ``if __name__ == "__main__":`` guard if the module is ever imported.
parser = ArgumentParser()
general_args, model_args, optim_args = (
    params_cls().add_to_parser(parser)
    for params_cls in (GeneralParams, ModelParams, OptimizationParams)
)
vargs = vars(parser.parse_args())
for _params in (general_args, model_args, optim_args):
    _params.update(**vargs)
del _params

## RUN
# python tata.py --iterations 2000 --source_path my_new_path --test_iterations 1 23 4
Author
Author
Using `object.__setattr__(self, key, kwarg[key])` instead solves it.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Weird:
the `int`/`str` fields of the `dataclass` don't get updated, but the `default_factory` ones do.