
Add support for a config file with optional overriding

Open yuvval opened this issue 4 years ago • 7 comments

Is your feature request related to a problem? Please describe. I would like to be able to load the args configuration from a JSON/YAML file and, at the same time, override some values from the command line.

Describe the solution you'd like. For example: python my_script.py --config=conf.yaml --lr=1e-2

Describe alternatives you've considered. Parsing the command-line argv for the overriding args and setting them manually; however, this becomes difficult when dataclasses are nested.

yuvval, Feb 23 '21

Hi @yuvval , sorry for not getting back to you sooner!

Hmm, that's interesting. You can already achieve what you want without a new feature. For instance, you could do something like this (which, I admit, is a bit long, but gets the job done!):

from dataclasses import dataclass, fields, is_dataclass, asdict
from simple_parsing import ArgumentParser, Serializable
from typing import Dict, Union, Any
from simple_parsing.utils import dict_union


@dataclass
class HParams(Serializable):
    lr: float = 0.01
    foo: str = "hello"

    def replace(self, **new_params):
        new_hp_dict = dict_union(asdict(self), new_params, recurse=True)
        new_hp = type(self).from_dict(new_hp_dict)
        return new_hp


def differing_values(target: HParams, reference: HParams) -> Dict[str, Union[Any, Dict]]:
    """ Given a dataclass and a 'reference' dataclass, returns a possibly nested dict
    of all the values that are different in `target` compared to in `reference`.
    """
    non_default_values = {}
    for field in fields(target):
        name = field.name
        target_value = getattr(target, name)
        reference_value = getattr(reference, name)
        if target_value == reference_value:
            continue
        if is_dataclass(target_value) and is_dataclass(reference_value):
            # Recurse in the case of unequal dataclasses.
            non_default_values[name] = differing_values(target_value, reference_value)
        else:
            non_default_values[name] = target_value
    return non_default_values


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_arguments(HParams, "hparams")
    parser.add_argument("--config_path", default="")
    args = parser.parse_args()

    config_path: str = args.config_path
    hparams: HParams = args.hparams

    default_hparams = HParams()

    if config_path:
        # IDEA: Create a new HParams object whose values are those from the config,
        # plus any hparam value parsed from the command line that differs from the default.
        config_hparams = HParams.load(config_path)
        new_kwargs = differing_values(hparams, default_hparams)
        hparams = config_hparams.replace(**new_kwargs)

    print(f"Hparams: {hparams}")

Then, assuming file.yaml contains this:

foo: bar
lr: 0.05

You can then get what you want:

$ python test/test_issue45.py
Hparams: HParams(lr=0.01, foo='hello')
$ python test/test_issue45.py --config_path file.yaml
Hparams: HParams(lr=0.05, foo='bar')
$ python test/test_issue45.py --config_path file.yaml --foo bob
Hparams: HParams(lr=0.05, foo='bob')
$ python test/test_issue45.py --config_path file.yaml --lr 4.123
Hparams: HParams(lr=4.123, foo='bar')
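The original request mentions nested dataclasses. The `replace` method above relies on `dict_union(..., recurse=True)` so that overriding one nested field does not wipe out its siblings. As a stdlib-only illustration of that idea, here is a minimal sketch with a hypothetical `merge_nested` helper (not simple_parsing's actual implementation) and a made-up nested `optim` group:

```python
from typing import Any, Dict


def merge_nested(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict: `base` with `overrides` applied, recursing into nested dicts.

    Hypothetical helper, similar in spirit to a recursive dict union.
    """
    merged = dict(base)  # shallow copy so `base` is left untouched
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts (e.g. a nested dataclass): merge field by field.
            merged[key] = merge_nested(merged[key], value)
        else:
            # Leaf value (or type mismatch): the override wins.
            merged[key] = value
    return merged


# Values loaded from the config file, including a hypothetical nested group.
config = {"lr": 0.05, "foo": "bar", "optim": {"name": "adam", "momentum": 0.9}}
# Values parsed from the command line that differ from the defaults.
cli_overrides = {"optim": {"momentum": 0.5}}

print(merge_nested(config, cli_overrides))
# {'lr': 0.05, 'foo': 'bar', 'optim': {'name': 'adam', 'momentum': 0.5}}
```

Only `momentum` changes; `optim.name` survives because the merge descends into the nested dict instead of replacing it wholesale.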

Hope this helps! I'll close the issue for now, let me know if you have any other questions.

lebrice, Apr 27 '21

Here's how I do it. Perhaps this would make a good example.

from typing import Optional
from pathlib import Path
from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing.helpers import Serializable

@dataclass
class CfgFileConfig:
    """Config file args"""
    load_config: Optional[Path] = None          # load config file
    save_config: Optional[Path] = None          # save config to specified file

@dataclass
class MyPreferences(Serializable):
    light_switch: bool = True # turn on light if true
    
def main(args=None):
    # Pass 1: parse only the config-file options, leaving everything else for later.
    cfgfile_parser = ArgumentParser(add_help=False)
    cfgfile_parser.add_arguments(CfgFileConfig, dest="cfgfile")
    cfgfile_args, rest = cfgfile_parser.parse_known_args(args)

    cfgfile: CfgFileConfig = cfgfile_args.cfgfile

    file_config: Optional[MyPreferences] = None
    if cfgfile.load_config is not None:
        file_config = MyPreferences.load(cfgfile.load_config)

    # Pass 2: the full parser, using the loaded file (if any) as the defaults.
    parser = ArgumentParser()

    # Add the cfgfile args again so they appear in the help message.
    parser.add_arguments(CfgFileConfig, dest="cfgfile")
    parser.add_arguments(MyPreferences, dest="my_preferences", default=file_config)

    parsed = parser.parse_args(args)

    prefs: MyPreferences = parsed.my_preferences
    print(prefs)

    if cfgfile.save_config is not None:
        prefs.save(cfgfile.save_config, indent=4)


if __name__ == '__main__':
    main()
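For comparison, the same two-pass pattern can be sketched with only the standard library. This is a minimal sketch assuming a JSON config file and plain argparse flags; `--load_config`, `--lr` and `--foo` here are illustrative and not simple_parsing's API. The first parser only extracts the config path; the second uses the file's values as its defaults via `set_defaults`, so an explicit command-line value still wins, even when it equals the built-in default.

```python
import argparse
import json
from typing import List, Optional


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # Pass 1: only extract the config path; every other flag is left for pass 2.
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--load_config", default=None)
    known, _rest = pre.parse_known_args(argv)

    # Pass 2: the full parser (it inherits --load_config so it shows in --help).
    parser = argparse.ArgumentParser(parents=[pre])
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--foo", default="hello")
    if known.load_config:
        # Values from the file replace the declared defaults...
        with open(known.load_config) as f:
            parser.set_defaults(**json.load(f))
    # ...and anything passed explicitly on the command line overrides both.
    return parser.parse_args(argv)
```

With a file containing {"lr": 0.05, "foo": "bar"}, parse(["--load_config", path, "--lr", "0.01"]) yields lr=0.01 and foo='bar'.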

stevebyan, Feb 10 '22

A downside to the diff approach is that it fails when a value equal to the default is passed on the command line. The user expects the command-line argument to override the config file's value, but it won't, because it happens to equal the dataclass's default value. I think the best workaround is to parse twice.
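To make the pitfall concrete, here is a stdlib-only sketch (plain argparse, hypothetical `--lr`/`--foo` flags, and an already-loaded config dict) contrasting the diff approach with one possible reading of "parse twice": a second parse in which every default is `argparse.SUPPRESS`, so only the options the user actually typed end up in the namespace.

```python
import argparse
from typing import Any, Dict, List


def build_parser(suppress_defaults: bool = False) -> argparse.ArgumentParser:
    # With default=argparse.SUPPRESS, an option that is not given on the
    # command line is simply absent from the resulting Namespace.
    def maybe_suppress(value: Any) -> Any:
        return argparse.SUPPRESS if suppress_defaults else value

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=maybe_suppress(0.01))
    parser.add_argument("--foo", default=maybe_suppress("hello"))
    return parser


def resolve_by_diff(argv: List[str], config: Dict[str, Any]) -> Dict[str, Any]:
    # Diff approach: keep only parsed values that differ from the defaults.
    defaults = vars(build_parser().parse_args([]))
    parsed = vars(build_parser().parse_args(argv))
    changed = {k: v for k, v in parsed.items() if v != defaults[k]}
    return {**defaults, **config, **changed}


def resolve(argv: List[str], config: Dict[str, Any]) -> Dict[str, Any]:
    defaults = vars(build_parser().parse_args([]))
    # Second parse with suppressed defaults: only explicitly typed options survive.
    explicit = vars(build_parser(suppress_defaults=True).parse_args(argv))
    # Precedence: declared defaults < config file < explicit command-line values.
    return {**defaults, **config, **explicit}


config = {"lr": 0.05, "foo": "bar"}  # as if loaded from the config file
print(resolve_by_diff(["--lr", "0.01"], config))  # {'lr': 0.05, 'foo': 'bar'}
print(resolve(["--lr", "0.01"], config))          # {'lr': 0.01, 'foo': 'bar'}
```

The diff version silently drops `--lr 0.01` because it equals the default, so the config's 0.05 wins; the suppressed-defaults parse records that the user typed it.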

psirenny, Mar 25 '22

It'd be really nice to make it a built-in option, like in Pyrallis or Clout.

andrey-klochkov-liftoff, May 25 '22

Sure thing, this makes sense. I'll take a look.

lebrice, May 25 '22

Any news?

Yevgnen, Jul 25 '22

No news atm; I've been busy with other stuff. I'll push this onto my stack of TODOs. Hopefully I'll have something to show for it in a week or two!

lebrice, Jul 26 '22

OK, I've added better support for this in #158. Let me know what you think! :)

lebrice, Aug 19 '22