Skip to content

Instantly share code, notes, and snippets.

@mikel-zhobro
Last active January 26, 2024 17:13
Show Gist options
  • Select an option

  • Save mikel-zhobro/0a75a2b6bb937da7128a2e0a24950734 to your computer and use it in GitHub Desktop.

Select an option

Save mikel-zhobro/0a75a2b6bb937da7128a2e0a24950734 to your computer and use it in GitHub Desktop.
One file dataclass argparse
import argparse
import builtins
import sys
from argparse import ArgumentParser
from dataclasses import dataclass, fields, Field, field
from typing import Any, get_args
class col:
    """ANSI SGR escape sequences used to colour terminal output.

    Print one of the colour/attribute codes before the text and ``ENDC``
    after it to restore the terminal's default style.
    """

    # Colours (bright variants use the 9x codes).
    HEADER = "\033[95m"    # bright magenta
    OKBLUE = "\033[94m"    # bright blue
    OKCYAN = "\033[96m"    # bright cyan
    OKGREEN = "\033[92m"   # bright green
    OKYELLOW = "\033[33m"  # normal yellow (same string as "\33[33m")
    WARNING = "\033[93m"   # bright yellow
    FAIL = "\033[91m"      # bright red

    # Reset and text attributes.
    ENDC = "\033[0m"       # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
class DotDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    Intended as the base of ``@dataclass`` parameter groups. The dataclass
    ``__init__`` assigns each field via ``__setattr__``, which is
    ``dict.__setitem__`` here, so field values are stored as dict items.

    Reading an attribute checks the items first, then normal attribute lookup,
    and finally falls back to ``dict.get``. A missing attribute therefore
    yields ``None`` instead of raising ``AttributeError``. Because items win,
    a field named like a method (``update``, ``keys``, ...) would shadow it.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__ # type: ignore
    __delattr__ = dict.__delitem__ # type: ignore

    def __init__(self, name="DICT", print_changes=True):
        # A @dataclass subclass generates its own __init__ and does not call
        # this one. Such subclasses set ``_name`` in ``__post_init__``, and
        # ``_print_changes`` stays unset (``update`` then treats it as True).
        super().__init__()
        self._name = name
        self._print_changes = print_changes

    def __getattribute__(self, __name: str) -> Any:
        # Items take precedence over class attributes. @dataclass leaves a
        # class attribute holding the default of every field declared with a
        # plain default (e.g. ``source_path: str = "data"``). Without this
        # override, ``obj.source_path`` would keep returning that class-level
        # default even after the item was updated.
        if __name in super().keys():
            return super().get(__name, None)
        return super().__getattribute__(__name)

    def __getstate__(self):
        # Pickle hooks cover only the instance __dict__. The dict items
        # themselves are pickled by the standard dict-subclass reduce machinery.
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def update(self, **kwarg):
        """Overwrite existing items whose value differs in ``kwarg``; return self.

        Only keys this dict already holds are considered. One shared
        ``vars(parser.parse_args())`` namespace can therefore be passed to
        every params object, and each picks out its own fields. Unlike
        ``dict.update``, no new keys are ever added.

        A table of the changed keys (old -> new) is printed unless the
        ``_print_changes`` item is falsy. When that item is absent, as it is
        for @dataclass subclasses, the table is printed.
        """
        changed = [
            key for key in kwarg if key in self and self[key] != kwarg[key]
        ]
        if changed and self.get("_print_changes", True):
            print(
                f"\n{col.UNDERLINE}Updating: {self.get('_name', 'X')}-args {col.ENDC}\n"
            )
            print(
                f'{"key": ^15}: {col.FAIL} {"default": ^15} -> {col.OKGREEN} {"new":^15} {col.ENDC}'
            )
            print(
                f'{"-" * 15: ^15} {col.FAIL} {"-" * 15: ^15} {col.OKGREEN} {"-" * 15:^15} {col.ENDC}'
            )
            for key in changed:
                print(
                    f"{key: ^15}: {col.FAIL} {str(self[key]): ^15} -> {col.OKGREEN} {str(kwarg[key]):^15} {col.ENDC}"
                )
        # Apply after printing so the table shows the old values.
        for key in changed:
            self[key] = kwarg[key]
        return self

    def add_to_parser(self, parser: ArgumentParser):
        """Add one ``--<field>`` option per public dataclass field; return self.

        The options go into an argument group titled ``self._name``. Fields
        whose name starts with ``_`` are skipped.

        Each option defaults to this instance's current value for the field,
        or to the field's declared default when the instance holds no such
        item. Feeding the parsed namespace back into ``update`` therefore leaves
        fields not given on the command line unchanged, including values
        customised at construction time.

        The CLI type is derived from that default value:

        * ``bool``: ``store_true`` flag. When the default is ``True`` and
          ``argparse.BooleanOptionalAction`` exists, a ``--no-<field>`` switch
          is added as well, since ``store_true`` alone could never turn it off.
        * ``list``: ``nargs='+'``. Items are converted to the type of the first
          default element. For an empty list, ``X`` is taken from a
          ``list[X]`` annotation (``str`` if it cannot be resolved).
        * anything else: ``type=type(default)`` with ``nargs='?'``.
          NOTE(review): a bare ``--opt`` with no value then yields ``None``.

        Raises:
            TypeError: if ``self`` is not a dataclass instance (raised by
                ``dataclasses.fields``).
        """
        group = parser.add_argument_group(self._name)
        ls_fields: 'tuple[Field, ...]' = fields(self.__class__)
        for fld in ls_fields:
            if fld.name.startswith("_"):
                continue
            value = self._current_value(fld)
            arg_type = type(value)
            nargs = "?"
            if type(value) is list:
                arg_type = type(value[0]) if value else self._list_item_type(fld.type)
                nargs = "+"
                # Copy so the parser default does not alias the instance's list.
                value = list(value)
            if arg_type is bool:
                action = "store_true"
                if value is True and hasattr(argparse, "BooleanOptionalAction"):
                    action = argparse.BooleanOptionalAction
                group.add_argument("--" + fld.name, default=value, action=action)
            else:
                group.add_argument(
                    "--" + fld.name, default=value, type=arg_type, nargs=nargs
                )
        return self

    def _current_value(self, fld: Field) -> Any:
        """Return this instance's item for ``fld``, else the field's declared default."""
        if fld.name in self:
            return self[fld.name]
        if callable(fld.default_factory):
            return fld.default_factory()
        return fld.default

    def _list_item_type(self, annotation: Any) -> Any:
        """Resolve ``X`` from a ``list[X]``-style annotation; ``str`` when unknown.

        Real generic aliases (``list[int]``, ``typing.List[int]``) are handled
        by ``typing.get_args``. For string annotations (``'list[int]'``), the
        inner name is looked up in the subclass's module and then in builtins;
        the annotation text is never ``eval``-ed.
        """
        args = get_args(annotation)
        if args:
            return args[0]
        if isinstance(annotation, str) and "[" in annotation:
            inner = annotation.split("[", 1)[1].split("]", 1)[0].strip()
            module = sys.modules.get(type(self).__module__)
            namespace = dict(vars(builtins))
            if module is not None:
                namespace.update(vars(module))
            return namespace.get(inner, str)
        return str
@dataclass
class GeneralParams(DotDict):
    """General run settings, shown as the "GENERAL" argument group."""

    # default_factory builds a fresh list per instance (no shared mutable default).
    # NOTE(review): annotated list[str], unlike the int-typed *_iterations
    # fields below; confirm this is intended.
    checkpoint_iterations: 'list[str]' = field(default_factory=list)
    test_iterations: 'list[int]' = field(default_factory=lambda: [7_000, 30_000])
    save_iterations: 'list[int]' = field(default_factory=lambda: [7_000, 30_000])

    def __post_init__(self):
        # Title of this object's argparse group and of its update() printout.
        self._name = "GENERAL"
@dataclass
class ModelParams(DotDict):
    """Input/output location settings, shown as the "MODEL" argument group."""

    source_path: str = "data"  # CLI: --source_path
    model_path: str = ""       # CLI: --model_path (empty by default)

    def __post_init__(self):
        # Title used for the argparse group and the update() change table.
        self._name = "MODEL"
@dataclass
class OptimizationParams(DotDict):
    """Optimisation hyper-parameters, shown as the "OPTIM" argument group."""

    # Annotated int to match its default (200). The CLI type is derived from
    # the default value, so --iterations is parsed as an int. This is also
    # consistent with the int-typed *_iterations fields of GeneralParams.
    iterations: int = 200
    random_background: bool = False  # boolean flag on the CLI
    lambda_loss: float = 0.2
    lr_init: float = 0.00016
    lr_final: float = 0.0000016
    lr_delay_mult: float = 0.01

    def __post_init__(self):
        # Title used for the argparse group and the update() change table.
        self._name = "OPTIM"
# Build one CLI with an option group per parameter dataclass, parse argv once,
# and push the parsed values back into each params object.
# NOTE(review): this runs at import time and reads sys.argv. Importing this
# module from another program (or a test runner) parses *that* program's
# arguments; consider moving it under ``if __name__ == "__main__":``.
parser = ArgumentParser()
general_args, model_args, optim_args = (
    params_cls().add_to_parser(parser)
    for params_cls in (GeneralParams, ModelParams, OptimizationParams)
)

# Each update() only touches keys its object already holds, so the full
# namespace can be handed to all three.
vargs = vars(parser.parse_args())
for _params in (general_args, model_args, optim_args):
    _params.update(**vargs)

## RUN
# python tata.py --iterations 2000 --source_path my_new_path --test_iterations 1 23 4
@mikel-zhobro
Copy link
Author

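Weird: fields with plain `int`/`str` defaults don't get updated, but the `default_factory` ones do. Probable cause: `@dataclass` keeps a class attribute holding each plain default, and normal attribute lookup finds it before ever falling back to the dict-backed `__getattr__`. Fields declared with `default_factory` have their class attribute removed, so their lookup falls through to the dict.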

@mikel-zhobro
Copy link
Author

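Using `object.__setattr__(self, key, kwarg[key])` solves it: the new value becomes an instance attribute, which takes precedence over the class attribute. The current code above takes a different route. It overrides `__getattribute__` so that dict items are checked first. With that override in place, `object.__setattr__` would no longer work, because the stale dict item would still win.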

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment