[BUG] Cannot train DeepARModel with a custom Loss function
🐛 Bug Report
I have a time-series regressor model with `timestamp`, `segment`, `target`, and other exogenous features. The model runs in a Kaggle notebook.
```python
# Imports reconstructed for a self-contained snippet (etna 1.11 / pytorch_forecasting 0.10).
# HORIZON, cat_cols, ts_labels and ts_exog are defined earlier in the notebook.
import etna.metrics
from etna.datasets import TSDataset
from etna.models.nn import DeepARModel
from etna.pipeline import Pipeline
from etna.transforms import PytorchForecastingTransform
from pytorch_forecasting.data import GroupNormalizer

pf_transform = PytorchForecastingTransform(
    max_encoder_length=HORIZON,
    max_prediction_length=HORIZON,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    time_varying_known_categoricals=cat_cols,
    target_normalizer=GroupNormalizer(groups=["segment"]),
)
model = DeepARModel(max_epochs=5, learning_rate=[0.01], gpus=1, batch_size=32)
pipeline = Pipeline(model=model, transforms=[pf_transform], horizon=HORIZON)

etna_ts = TSDataset(df=ts_labels, freq="20T", df_exog=ts_exog, known_future="all")
metrics, forecasts, _ = pipeline.backtest(
    etna_ts, n_folds=3, n_jobs=-1, metrics=[etna.metrics.MAE()], aggregate_metrics=True
)
```
The code above works fine. It also works when I change the model to the following:
```python
import pytorch_forecasting.metrics

model = DeepARModel(
    max_epochs=5,
    learning_rate=[0.01],
    loss=pytorch_forecasting.metrics.NormalDistributionLoss(),
    gpus=1,
    batch_size=32,
)
```
But when I change it to the following, it fails with an error:
```python
model = DeepARModel(
    max_epochs=5,
    learning_rate=[0.01],
    loss=pytorch_forecasting.metrics.MAE(),
    gpus=1,
    batch_size=32,
)
```
The error is:
```text
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/opt/conda/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 595, in __call__
    return self.func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/joblib/parallel.py", line 263, in __call__
    for func, args, kwargs in self.items]
  File "/opt/conda/lib/python3.7/site-packages/joblib/parallel.py", line 263, in <listcomp>
    for func, args, kwargs in self.items]
  File "/opt/conda/lib/python3.7/site-packages/etna/pipeline/base.py", line 391, in _run_fold
    pipeline.fit(ts=train)
  File "/opt/conda/lib/python3.7/site-packages/etna/pipeline/pipeline.py", line 48, in fit
    self.model.fit(self.ts)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/base.py", line 45, in wrapper
    result = f(self, *args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/nn/deepar.py", line 153, in fit
    self.model = self._from_dataset(pf_transform.pf_dataset_train)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/nn/deepar.py", line 123, in _from_dataset
    loss=self.loss,
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/deepar/__init__.py", line 188, in from_dataset
    dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 1474, in from_dataset
    return super().from_dataset(dataset, **new_kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 1797, in from_dataset
    return super().from_dataset(dataset, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 987, in from_dataset
    net = cls(**kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/deepar/__init__.py", line 139, in __init__
    ), "number of targets should be equivalent to number of loss metrics"
AssertionError: number of targets should be equivalent to number of loss metrics
```
In all cases I see these warnings:
```text
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class (NormalDistributionLoss). The property determines if `update` by
default needs access to the full metric state. If this is not the case, significant speedups can be
achieved and we recommend setting this to `False`.
We provide an checking function
`from torchmetrics.utilities import check_forward_no_full_state`
that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
default for now) or if `full_state_update=False` can be used safely.
  warnings.warn(*args, **kwargs)
[Parallel(n_jobs=-1)]: Using backend MultiprocessingBackend with 2 concurrent workers.
```
The same `full_state_update` warning is then repeated verbatim for `SMAPE`, `MAE`, `RMSE`, `MAPE`, and `MASE` (once per backtest worker), interleaved with:
```text
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:152: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
```
Expected behavior
Using `loss=pytorch_forecasting.metrics.MAE()` should either work or fail with a more meaningful error.
How To Reproduce
The source code is included in the "Bug Report" section above.
Environment
- Python: 3.7 (Kaggle doesn't support Python 3.8)
- etna: 1.11.1
- pytorch_forecasting: 0.10.1
Additional context
No response
Checklist
- [X] Bug appears at the latest library version
Hi, @mirik123!
Thank you for your bug report.
We have no explicit type checking here, but there is typing support in the signature: `DeepARModel(..., loss: Optional["DistributionLoss"], ...)`.
So you can only use `DistributionLoss`-based losses. DeepAR predicts the parameters of a distribution rather than point forecasts, which is why a point metric like `MAE` trips the assertion above.
I guess the typing annotation is the only guarantee we can make for now.
We'll try to design a general mechanism for explicit type checking across all models and signatures. If you have any ideas, you're welcome to share them.
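For illustration, here is a minimal sketch of the kind of check we have in mind. The `check_deepar_loss` helper is hypothetical and not part of etna; the loss classes are real `pytorch_forecasting.metrics` classes:
```python
from pytorch_forecasting.metrics import (
    MAE,
    DistributionLoss,
    NegativeBinomialDistributionLoss,
    NormalDistributionLoss,
)

def check_deepar_loss(loss):
    # Hypothetical guard: DeepAR emits distribution parameters, so only
    # DistributionLoss subclasses are compatible with its output head.
    if not isinstance(loss, DistributionLoss):
        raise TypeError(
            f"DeepARModel expects a DistributionLoss, got {type(loss).__name__}"
        )

check_deepar_loss(NormalDistributionLoss())            # ok
check_deepar_loss(NegativeBinomialDistributionLoss())  # ok
check_deepar_loss(MAE())                               # raises a clear TypeError
```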
Hi @martins0n,
Thank you for the update.
What about all these warnings produced by `pipeline.backtest`? Is that OK?
pytorch_lightning: v1.6.5
```text
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:152: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:486: RuntimeWarning: ReduceLROnPlateau conditioned on metric val_loss which is not available but strict is set to `False`. Skipping learning rate update.
```
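A minimal workaround sketch for quieting these specific messages in a notebook (the patterns are copied from the warnings above; this only hides the noise and may not reach the joblib worker processes spawned by `backtest`):
```python
import warnings

# Hide the known-benign warnings quoted above. `message` is a regex
# matched against the start of the warning text, hence the leading ".*".
warnings.filterwarnings("ignore", message=".*full_state_update.*")
warnings.filterwarnings("ignore", message=".*is an instance of `nn.Module`.*")
warnings.filterwarnings("ignore", message=".*checkpoint_callback.*deprecated.*")
warnings.filterwarnings("ignore", message=".*ReduceLROnPlateau conditioned on metric.*")
```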
@mirik123 The warnings are OK; they don't affect the final results. But we're going to resolve them all where possible.
`checkpoint_callback=False` could cause errors when using pytorch_lightning>=1.7.0. We have fixed this in #866, so it will not raise warnings in the next release.
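For reference, the rename described by the deprecation message above (a minimal sketch of the two flags; nothing etna-specific is assumed here):
```python
import pytorch_lightning as pl

# Deprecated since pytorch_lightning 1.5 and removed in 1.7:
# trainer = pl.Trainer(checkpoint_callback=False)

# Replacement suggested by the warning itself, available since 1.5:
trainer = pl.Trainer(enable_checkpointing=False)
```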