class_weights cannot be passed via a config file because a tensor is expected
Description
Using the Lightning CLI, we can train the SemanticSegmentationTask, but we cannot set class_weights without an error. The solution is to accept a list of ints in addition to a tensor.
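A minimal sketch of the proposed fix follows. It assumes a hypothetical helper `to_class_weights` that the task could call in its `__init__` to normalize whatever the parser hands it. The widened `Union[Tensor, Sequence[float]]` hint is also an assumption: whether the Lightning CLI then validates `[1, 50]` against it has not been checked here.

```python
from typing import Optional, Sequence, Union

import torch
from torch import Tensor


def to_class_weights(
    class_weights: Optional[Union[Tensor, Sequence[float]]],
) -> Optional[Tensor]:
    """Hypothetical helper: normalize class weights to a float tensor (or None)."""
    if class_weights is None:
        return None
    if isinstance(class_weights, Tensor):
        # Already a tensor (e.g. passed from Python code): cast to float32,
        # since losses such as torch.nn.CrossEntropyLoss expect float weights.
        return class_weights.float()
    # A YAML list such as [1, 50] arrives as a plain Python list of ints.
    return torch.tensor(class_weights, dtype=torch.float)
```

For example, `to_class_weights([1, 50])` returns `tensor([1., 50.])`, so code that already passes a tensor keeps working while YAML users can pass a list.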
Steps to reproduce
In the Lightning CLI YAML config:
model:
  class_path: SemanticSegmentationTask
  init_args:
    model: unet
    backbone: resnet50
    weights: null
    lr: 0.001
    in_channels: 6
    num_classes: 2
    class_weights:
      - 1
      - 50
This results in the following error:
  Does not validate against any of the Union subtypes
  Subtypes: (<class 'torch.Tensor'>, <class 'NoneType'>)
  Errors:
    - Not a valid subclass of Tensor
      Subclass types expect one of:
      - a class path (str)
      - a dict with class_path entry
      - a dict without class_path but with init_args entry (class path given previously)
    - Expected a <class 'NoneType'>
  Given value type: <class 'list'>
  Given value: [1, 50]
Version
main
Care to make a PR to accept a list and convert it to a tensor? If not, I can take it on this weekend.
You will get to it way before me!
For a bit of history, I added this in #1221 and it initially only supported lists. In #1413, @ntw-au modified this to support lists, numpy arrays, and torch tensors. Then in #1541, I modified it to only accept torch tensors. I agree we need a way to support class_weights in a YAML file (and preferably also on the command line). If omegaconf supports this, we could also easily enable omegaconf as a parser: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced_2.html#enable-variable-interpolation.
If you want to use it with hydra.utils.instantiate and omegaconf, you would only need the following:
class_weights:
  _target_: torch.tensor
  data: [0.5, 0.5]
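To illustrate what the `_target_` form asks for, here is a simplified, hypothetical re-implementation (`instantiate_target` is an illustrative name, not Hydra's code): resolve the dotted path, then call it with the remaining keys as keyword arguments. Hydra's real `instantiate` also handles positional args, recursion, partials, and more.

```python
import importlib
from typing import Any, Dict


def instantiate_target(config: Dict[str, Any]) -> Any:
    """Hypothetical sketch: call the object named by `_target_` with the other keys."""
    kwargs = dict(config)  # copy so the caller's config is not mutated
    module_name, _, attr = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**kwargs)
```

Under this sketch, the YAML above would amount to calling `torch.tensor(data=[0.5, 0.5])`.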
I haven't looked at the Lightning CLI in a while, but I wonder if it supports recursive instantiation like:
class_weights:
  class_path: torch.tensor
  init_args:
    data: [0.5, 0.5]
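As a rough sketch of what "recursive instantiation" would mean for the `class_path`/`init_args` form, the hypothetical `instantiate_recursive` below walks a nested config and builds any dict that has a `class_path`, instantiating its `init_args` first. This is not jsonargparse's implementation. Note also that the error output earlier checks for "a valid subclass of Tensor", so it is unconfirmed whether a factory function such as `torch.tensor` would be accepted as a `class_path`.

```python
import importlib
from typing import Any


def _resolve(path: str) -> Any:
    """Import `package.module.attr` and return `attr`."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_recursive(node: Any) -> Any:
    """Hypothetical sketch: build objects from nested class_path/init_args dicts."""
    if isinstance(node, list):
        return [instantiate_recursive(item) for item in node]
    if isinstance(node, dict):
        if "class_path" in node:
            # Build nested init_args first, then call the resolved target.
            init_args = {
                key: instantiate_recursive(value)
                for key, value in node.get("init_args", {}).items()
            }
            return _resolve(node["class_path"])(**init_args)
        # Plain mapping: recurse into its values and leave keys untouched.
        return {key: instantiate_recursive(value) for key, value in node.items()}
    # Scalars (numbers, strings, None) pass through unchanged.
    return node
```

With `class_path: torch.tensor` and `init_args: {data: [0.5, 0.5]}`, this sketch would call `torch.tensor(data=[0.5, 0.5])` and place the result under `class_weights`.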