-
-
Save seba-1511/ac7545b5b87c057494f45dedfcbf0b10 to your computer and use it in GitHub Desktop.
A principled way to define a config dictionary that can be saved and restored, supports predefined types and default values, and supports command-line parsing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import yaml | |
| import argparse | |
class _ConfigDict(dict):
    """A dict subclass whose construction requires every declared field.

    Subclasses (normally built with ``config_dict``) set ``fields`` to a list
    of ``(name, type, default, help)`` tuples.  The same declaration drives
    command-line parsing (``get_cmd_parser``), validation on construction,
    and YAML save/restore (``to_yaml`` / ``from_yaml``).
    """

    # Change this to customize it: list of (name, type, default, help) tuples.
    fields = []

    @classmethod
    def get_cmd_parser(cls, parser=None):
        """Return an argparse.ArgumentParser with one option per field.

        Args:
            parser: existing ``argparse.ArgumentParser`` to extend; a new one
                is created when ``None``.

        Returns:
            The parser, with ``-<name>`` added for every entry in ``fields``.
        """
        if parser is None:
            parser = argparse.ArgumentParser()
        for field in cls.fields:
            name, field_type, default, help_text = field[0], field[1], field[2], field[3]
            parser.add_argument('-' + name, type=field_type, default=default,
                                help=help_text)
        return parser

    def __init__(self, **kwargs):
        """Build the config from keyword arguments after checking all fields are present.

        Raises:
            AssertionError: if any declared field is missing from ``kwargs``.
        """
        self.validate_config_dict(kwargs)
        super(_ConfigDict, self).__init__(**kwargs)

    @classmethod
    def from_yaml(cls, yaml_file):
        """Load a config from the YAML file at ``yaml_file``.

        The file is parsed with ``yaml.safe_load``. The plain ``yaml.load``
        without a Loader fails on PyYAML >= 6 and would otherwise let the file
        construct arbitrary Python objects.

        Returns:
            A new instance of ``cls``, or ``None`` if the file is not valid
            YAML (the parse error is printed, matching the original
            best-effort behavior).

        Raises:
            TypeError: if the YAML document is not a mapping.
            AssertionError: if a declared field is missing from the file.
        """
        with open(yaml_file, 'r') as stream:
            try:
                data = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                return None
        if data is None:
            # An empty document parses to None; treat it as "no keys" so the
            # user gets the missing-fields error rather than a bare TypeError.
            data = {}
        if not isinstance(data, dict):
            raise TypeError('Config file {} must contain a YAML mapping, got {}'
                            .format(yaml_file, type(data).__name__))
        return cls(**data)

    def to_yaml(self, yaml_file):
        """Save the config as block-style YAML to ``yaml_file``.

        Uses ``yaml.safe_dump`` so anything written can be read back by
        ``from_yaml``. Values that are not plain YAML types raise
        ``yaml.representer.RepresenterError`` instead of producing an
        unloadable file.
        """
        data = dict(self)
        with open(yaml_file, 'w') as out:
            yaml.safe_dump(data, out, default_flow_style=False)

    @classmethod
    def validate_config_dict(cls, config_dict):
        """Check that ``config_dict`` contains every name in ``fields``.

        Raises:
            AssertionError: listing the missing names (sorted). It is raised
                explicitly, not via ``assert``, so the check still runs under
                ``python -O`` while keeping the exception type callers see.
        """
        required = {field[0] for field in cls.fields}
        missing_fields = required.difference(config_dict)
        if missing_fields:
            raise AssertionError('Missing fields for game config: '
                                 + str(sorted(missing_fields)))

    def display(self):
        """Print each ``key:value`` pair on its own line."""
        for k, v in self.items():
            print('{}:{}'.format(k, v))
def config_dict(name, fields):
    """Create and return a new ``_ConfigDict`` subclass.

    Args:
        name: the ``__name__`` of the generated class.
        fields: list of ``(name, type, default, help)`` tuples, stored as the
            new class's ``fields`` attribute.

    Returns:
        The generated class itself (not an instance).
    """
    class_attrs = {'fields': fields}
    bases = (_ConfigDict,)
    return type(name, bases, class_attrs)
if __name__ == '__main__':
    # Declare the fields as (name, type, default, help text) and build the class.
    demo_fields = [
        ('lr', float, 0.01, 'learning rate'),
        ('mom', float, 0.9, 'sgd momentum'),
    ]
    MyConfigDict = config_dict('MyConfig', demo_fields)

    # Populate a config from the command line; unspecified options use defaults.
    cli_args = MyConfigDict.get_cmd_parser().parse_args()
    config = MyConfigDict(**vars(cli_args))
    print('old config')
    config.display()

    # Write it out, then read it back two ways.
    config.to_yaml('config.yaml')

    # from_yaml called on the class ...
    reloaded = MyConfigDict.from_yaml('config.yaml')
    print('new config')
    reloaded.display()

    # ... and on an instance (it is a classmethod, so both forms work).
    reloaded_again = config.from_yaml('config.yaml')
    print('new config 2')
    reloaded_again.display()

    # Values may also be given directly as keyword arguments.
    by_hand = MyConfigDict(lr=5, mom=0.9)
    print('create config by directly input args')
    by_hand.display()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment