armory
Some model training parameters in a configuration are not being used
Some parameters are not being drawn from the configuration file but are hard-coded instead. For example:
- input shape: https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/resnet18.py#L56
- input shape: https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/micronnet_gtsrb.py#L76
- learning rate: https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/micronnet_gtsrb.py#L74
- learning rate: https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/cifar.py#L60
A desirable behavior might be: if a parameter is specified in the config but not used, armory should warn or raise an error.
@davidslater ^^ I wanted to get this on your radar, since it may lead to errors in the future.
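The warn-or-error behavior suggested above could look roughly like the sketch below. It is not part of armory: `check_unused_kwargs` is a hypothetical helper, and it assumes the model builder can report which config keys it actually read.

```python
import warnings
from typing import Any, Dict, Iterable


def check_unused_kwargs(
    kwargs: Dict[str, Any], used: Iterable[str], strict: bool = False
) -> None:
    """Warn, or raise if strict, about config keys that were never read.

    Hypothetical helper: `kwargs` is a dict from the config (for example
    model_kwargs or wrapper_kwargs) and `used` lists the keys the model
    builder actually consumed.
    """
    unused = sorted(set(kwargs) - set(used))
    if not unused:
        return
    msg = f"Config parameters specified but not used: {unused}"
    if strict:
        raise ValueError(msg)
    warnings.warn(msg, UserWarning)


if __name__ == "__main__":
    # Hypothetical config values; the builder is assumed to read only
    # input_shape and to hard-code its learning rate
    model_kwargs = {"input_shape": [32, 32, 3], "learning_rate": 0.01}
    check_unused_kwargs(model_kwargs, used=["input_shape"])  # warns about learning_rate
```

A warning keeps existing configs running; `strict=True` would turn a silently ignored parameter into a hard failure.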
Some of the parameters mentioned above actually do use the config's wrapper_kwargs when present. For example, in the resnet18 one: https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/resnet18.py#L55-L57
input_shape=wrapper_kwargs.pop(
"input_shape", (224, 224, 3)
), # default to imagenet shape
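Because `pop` removes the key from wrapper_kwargs before the remaining entries are forwarded, a config value replaces the default instead of being passed twice. A minimal sketch of the same pattern applied to every explicitly passed argument, assuming a hypothetical `DummyWrapper` in place of ART's `PyTorchClassifier` and purely illustrative defaults (`(32, 32, 3)`, `nb_classes=10`):

```python
from typing import Any, Dict, Sequence


class DummyWrapper:
    """Hypothetical stand-in for an ART estimator such as PyTorchClassifier."""

    def __init__(self, input_shape: Sequence[int], nb_classes: int, **kwargs: Any) -> None:
        self.input_shape = tuple(input_shape)
        self.nb_classes = nb_classes
        self.extra = kwargs


def build_wrapper(wrapper_kwargs: Dict[str, Any]) -> DummyWrapper:
    # Copy so the caller's config dict is left untouched
    kwargs = dict(wrapper_kwargs)
    # Pop every argument that is also passed explicitly below; the default
    # applies only when the config does not supply a value
    input_shape = kwargs.pop("input_shape", (32, 32, 3))
    nb_classes = kwargs.pop("nb_classes", 10)
    # The remaining entries can be forwarded without a duplicate keyword
    return DummyWrapper(input_shape=input_shape, nb_classes=nb_classes, **kwargs)


wrapper = build_wrapper({"input_shape": [64, 64, 3]})
print(wrapper.input_shape)  # (64, 64, 3)
```

Copying before popping is a design choice: armory's own code pops from the dict it receives, which mutates it in place.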
Many of the others are indeed hard-coded. However, if those parameters are added to wrapper_kwargs
in the config, they will cause an error here (or in similar places in the other files):
https://github.com/twosixlabs/armory/blob/53a7ccbd122ce3e05118416ba0d5442dee737fb9/armory/baseline_models/pytorch/micronnet_gtsrb.py#L79
**wrapper_kwargs
For instance, if I add these lines to cifar10_baseline.json:
"wrapper_kwargs": {
"input_shape": [64, 64, 3]
}
I will get this error when running armory:
armory run scenario_configs/cifar10_baseline.json --check
...
File "/workspace/armory/baseline_models/pytorch/cifar.py", line 57, in get_art_model
wrapped_model = PyTorchClassifier(
└ <class 'art.estimators.classification.pytorch.PyTorchClassifier'>
TypeError: InputFilter object got multiple values for keyword argument 'input_shape'
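The error itself is plain Python behavior: an argument supplied both as an explicit keyword and through `**kwargs` raises `TypeError` before the callee runs. A self-contained reproduction, with a hypothetical `make_classifier` standing in for `PyTorchClassifier`:

```python
from typing import Any, Dict, Optional, Sequence


def make_classifier(
    model: Any = None, input_shape: Optional[Sequence[int]] = None, **kwargs: Any
) -> Any:
    """Hypothetical stand-in for a wrapper constructor like PyTorchClassifier."""
    return input_shape


def reproduce(wrapper_kwargs: Dict[str, Any]) -> str:
    try:
        # input_shape is hard-coded here AND may also arrive via **wrapper_kwargs
        make_classifier(model=None, input_shape=(32, 32, 3), **wrapper_kwargs)
    except TypeError as err:
        return str(err)
    return "no error"


print(reproduce({"input_shape": [64, 64, 3]}))
print(reproduce({}))  # no error
```

The prefix of the message names the callable and its wording differs across Python versions; the "InputFilter object" text in the traceback presumably reflects that ART's classifier classes are created with a metaclass of that name.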