omegaconf
Interpolation to an index of a custom-resolver interpolation fails, depending on the order of structured config fields
Describe the bug
When interpolation B tries to access index i of a custom interpolation A that resolves to a list, index i cannot be accessed if the field/key of interpolation A does not appear before the field/key of interpolation B in the defining structured config.
To Reproduce
In the following, interpolation A is the one on the `input_size` key, interpolation B is the one on `n_channels`, and i = 0.
from dataclasses import dataclass
from typing import List
from omegaconf import SI, OmegaConf # , ListConfig
def get_input_size():
    """Return the input size as a fresh plain ``list`` for the ``get_input_size`` resolver.

    Per the reporter: returning a plain list works with TrainingConfigWhichWorks
    but not with TrainingConfigWhichFails, whereas returning
    ``ListConfig([3, 224, 224])`` instead works with both.
    """
    dims = (3, 224, 224)
    return list(dims)
@dataclass
class TrainingConfigWhichWorks:
    """Structured config whose resolver-backed field precedes the field that indexes it.

    ``input_size`` is declared before ``n_channels``; per the report,
    ``OmegaConf.resolve`` succeeds on this ordering.
    """

    # Produced by the custom ``get_input_size`` resolver (a plain list in this repro).
    input_size: List[int] = SI("${get_input_size:}")
    # Relative interpolation (leading dot) to element 0 of the sibling ``input_size``.
    n_channels: int = SI("${.input_size.0}")
@dataclass
class TrainingConfigWhichFails:
    """Same two interpolations as TrainingConfigWhichWorks, in reverse field order.

    ``n_channels`` is declared before ``input_size``; per the report,
    ``OmegaConf.resolve`` fails on this ordering when ``get_input_size``
    returns a plain list.
    """

    # NOTE(review): unlike TrainingConfigWhichWorks, these fields are typed
    # float / List[float] instead of int / List[int]. The report attributes the
    # failure to field order only -- confirm the type difference is irrelevant.
    n_channels: float = SI("${.input_size.0}")
    input_size: List[float] = SI("${get_input_size:}")
# Make ${get_input_size:} resolvable before either config is created.
OmegaConf.register_new_resolver("get_input_size", get_input_size)
# Field order input_size -> n_channels: resolves as expected.
config = OmegaConf.create(TrainingConfigWhichWorks)
OmegaConf.resolve(config) # -> Works
print(config) # -> {'input_size': [3, 224, 224], 'n_channels': 3}
# Field order n_channels -> input_size: the reported failure.
config = OmegaConf.create(TrainingConfigWhichFails)
OmegaConf.resolve(config) # -> FAILS if get_input_size returns a list, but works if it is a ListConfig.
Expected behavior
Both `TrainingConfigWhichWorks` and `TrainingConfigWhichFails` should work.
This bug also has implications on hydra:
(a) the success or failure of instantiating objects becomes order dependent and
(b) some sub-configs cannot be instantiated independently of each other.
See the slightly longer example below for an illustration with hydra.
Additional context
- [ ] OmegaConf version: 2.1.1
- [ ] Python version: 3.7.12
- [ ] Operating system: MacOS 12.2.1
Optional: Example of implications for hydra
The following example shows how the OmegaConf bug above can affect Hydra. In particular, it illustrates points (a) and (b) mentioned above.
from dataclasses import dataclass, field
from typing import List

import hydra
from hydra import utils as hu
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf  # ListConfig
def get_input_size():
    """Resolver body for ``${get_input_size:}``: a new ``[3, 224, 224]`` list on every call."""
    return [3] + [224] * 2
@dataclass
class TransformConfig:
    """Transform sub-config holding the resolver-backed ``input_size``."""

    # The default is an interpolation string handled by the custom
    # ``get_input_size`` resolver, not a literal list.
    input_size: List[int] = "${get_input_size:}"
@dataclass
class DatasetConfig:
    """Dataset sub-config that reads its channel count from ``transform``."""

    # Absolute interpolation (no leading dot) to element 0 of
    # ``transform.input_size``.
    n_channels: int = "${transform.input_size.0}"
@dataclass
class TrainingConfigWhichWorks:
    """Root config with ``transform`` declared before ``dataset``.

    ``dataset.n_channels`` interpolates ``transform.input_size.0``; with
    ``transform`` first, the report says instantiating this config works.
    """

    # default_factory builds a fresh sub-config per instance. A bare
    # ``TransformConfig()`` default is shared between instances and is rejected
    # by ``dataclasses`` on Python 3.11+ (non-frozen dataclass instances are
    # unhashable, which 3.11+ treats as a mutable default).
    # Field order is significant for this repro: keep ``transform`` first.
    transform: TransformConfig = field(default_factory=TransformConfig)
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
@dataclass
class TrainingConfigWhichFails:
    """Root config with ``dataset`` declared before ``transform``.

    ``dataset.n_channels`` interpolates ``transform.input_size.0``, which is
    declared later; per the report, instantiating this config (or its
    ``dataset`` sub-config) fails.
    """

    # default_factory builds a fresh sub-config per instance. A bare
    # ``DatasetConfig()`` default is shared between instances and is rejected
    # by ``dataclasses`` on Python 3.11+ (non-frozen dataclass instances are
    # unhashable, which 3.11+ treats as a mutable default).
    # Field order is the point of this repro: keep ``dataset`` first.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    transform: TransformConfig = field(default_factory=TransformConfig)
# Register the resolver used by TransformConfig.input_size, then store both
# root configs under the names referenced by the @hydra.main decorators.
OmegaConf.register_new_resolver("get_input_size", get_input_size)
config_store = ConfigStore.instance()
config_store.store(name="training_config_which_works", node=TrainingConfigWhichWorks)
config_store.store(name="training_config_which_fails", node=TrainingConfigWhichFails)
@hydra.main(config_name="training_config_which_works", config_path=None)
def main_which_works(config):
    """Instantiate the root config, then its ``transform`` sub-config, printing each.

    ``config.dataset`` is deliberately not instantiated on its own: the
    reporter notes that call would fail (point (b) of the issue).
    """

    def _instantiate_and_print(label, node):
        # Instantiate ``node`` and print the result after ``label``.
        print(label, hu.instantiate(node))

    _instantiate_and_print("Config from main_which_works:", config)
    _instantiate_and_print("TransformConfig from main_which_works:", config.transform)
@hydra.main(config_name="training_config_which_fails", config_path=None)
def main_which_fails(config):
    """Try to instantiate the config stored with ``dataset`` ahead of ``transform``.

    Both calls are reported to fail. If the first one raises, the second is
    never reached in the same run.
    """
    hu.instantiate(config) # Fails
    hu.instantiate(config.dataset) # Also fails
if __name__ == "__main__":
    # Run both Hydra entry points in sequence; per the report, the first
    # succeeds and the second fails.
    main_which_works()
    main_which_fails()
Interesting. Thanks for the report, @cjsg.
Returning a `ListConfig` from `get_input_size` is probably the safest workaround for now.
Yes, using `ListConfig` works. Another workaround is to define a second custom resolver that returns the element at a given index (or a range) of a given sequence, e.g. something like:
_NO_STOP = object()  # sentinel, so an explicit stop=None can mean "to the end"


def slice(l, idx, stop=_NO_STOP):
    """Return ``l[idx]``, or the range ``l[idx:stop]`` when *stop* is supplied.

    Used as the OmegaConf resolver ``slice``, so an interpolation such as
    ``${slice:${.input_size},0}`` can index into the value of another
    (resolver-backed) interpolation. Passing a third argument, e.g.
    ``${slice:${.input_size},1,3}``, would return a range instead; ``null``
    as the third argument means "to the end".

    NOTE(review): this name shadows the built-in ``slice`` for the rest of
    the module; consider renaming it (and the ``register_new_resolver``
    call that passes it).
    """
    if stop is _NO_STOP:
        # Original single-index behaviour.
        return l[idx]
    return l[idx:stop]
# Expose ``slice`` to interpolations as ``${slice:<sequence>,<index>}``.
OmegaConf.register_new_resolver("slice", slice)
@dataclass
class TrainingConfig:
    """Same field order as the failing config, using the ``slice`` resolver workaround.

    ``n_channels`` is declared before ``input_size`` but reads element 0
    through ``${slice:...}`` instead of ``${.input_size.0}``.
    """

    # The nested ``${.input_size}`` is passed as the first resolver argument,
    # and 0 as the index.
    n_channels: float = SI("${slice:${.input_size},0}")
    input_size: List[float] = SI("${get_input_size:}")