
[BUG] Cannot train DeepARModel with a custom Loss function

Open · mirik123 opened this issue 2 years ago · 1 comment

🐛 Bug Report

I have a time-series regression model with timestamp, segment, target and other exogenous features. The model runs in a Kaggle notebook.

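import etna.metrics
import pytorch_forecasting  # used by the loss= variants below
from etna.datasets import TSDataset
from etna.models.nn import DeepARModel
from etna.pipeline import Pipeline
from etna.transforms import PytorchForecastingTransform
from pytorch_forecasting.data import GroupNormalizer

# HORIZON, cat_cols, ts_labels and ts_exog are defined elsewhere in the notebook (not shown)
pf_transform = PytorchForecastingTransform(
    max_encoder_length=HORIZON,
    max_prediction_length=HORIZON,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    time_varying_known_categoricals=cat_cols,
    target_normalizer=GroupNormalizer(groups=["segment"]),
)

model = DeepARModel(max_epochs=5, learning_rate=[0.01], gpus=1, batch_size=32)
pipeline = Pipeline(model=model, transforms=[pf_transform], horizon=HORIZON)
etna_ts = TSDataset(df=ts_labels, freq="20T", df_exog=ts_exog, known_future='all')
# 3-fold backtest scored with etna's MAE metric
metrics, forecasts, _ = pipeline.backtest(etna_ts, n_folds=3, n_jobs=-1, metrics=[etna.metrics.MAE()], aggregate_metrics=True)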

The code above works. When I change the model to the following, it also works:

model = DeepARModel(max_epochs=5, learning_rate=[0.01], loss=pytorch_forecasting.metrics.NormalDistributionLoss(), gpus=1, batch_size=32)

But when I change it to the following, it fails:

model = DeepARModel(max_epochs=5, learning_rate=[0.01], loss=pytorch_forecasting.metrics.MAE(), gpus=1, batch_size=32)

The error is:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/opt/conda/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 595, in __call__
    return self.func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/joblib/parallel.py", line 263, in __call__
    for func, args, kwargs in self.items]
  File "/opt/conda/lib/python3.7/site-packages/joblib/parallel.py", line 263, in <listcomp>
    for func, args, kwargs in self.items]
  File "/opt/conda/lib/python3.7/site-packages/etna/pipeline/base.py", line 391, in _run_fold
    pipeline.fit(ts=train)
  File "/opt/conda/lib/python3.7/site-packages/etna/pipeline/pipeline.py", line 48, in fit
    self.model.fit(self.ts)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/base.py", line 45, in wrapper
    result = f(self, *args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/nn/deepar.py", line 153, in fit
    self.model = self._from_dataset(pf_transform.pf_dataset_train)
  File "/opt/conda/lib/python3.7/site-packages/etna/models/nn/deepar.py", line 123, in _from_dataset
    loss=self.loss,
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/deepar/__init__.py", line 188, in from_dataset
    dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 1474, in from_dataset
    return super().from_dataset(dataset, **new_kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 1797, in from_dataset
    return super().from_dataset(dataset, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 987, in from_dataset
    net = cls(**kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_forecasting/models/deepar/__init__.py", line 139, in __init__
    ), "number of targets should be equivalent to number of loss metrics"
AssertionError: number of targets should be equivalent to number of loss metrics
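The assertion message is misleading here: there is one target and one loss. As far as I can tell from pytorch_forecasting's DeepAR constructor, the check that fires combines two conditions: a single (string) target must come with a DistributionLoss, and a list of targets must come with a MultiLoss of matching length. The sketch below paraphrases that condition with hypothetical stand-in classes (not the library's real classes or its verbatim source) to show why MAE() fails while NormalDistributionLoss() passes.

```python
class DistributionLoss:
    """Hypothetical stand-in for pytorch_forecasting.metrics.DistributionLoss."""


class NormalDistributionLoss(DistributionLoss):
    """Hypothetical stand-in for a distribution loss such as NormalDistributionLoss."""


class MAE:
    """Hypothetical stand-in for a point metric such as pytorch_forecasting.metrics.MAE."""


class MultiLoss:
    """Hypothetical stand-in for a MultiLoss that holds one loss per target."""

    def __init__(self, metrics):
        self.metrics = list(metrics)

    def __len__(self):
        return len(self.metrics)


def loss_fits_target(target, loss) -> bool:
    """Approximate the combined condition behind the DeepAR assertion.

    A single (string) target needs a DistributionLoss; a list of targets
    needs a MultiLoss with one entry per target.
    """
    single_target_ok = isinstance(target, str) and isinstance(loss, DistributionLoss)
    multi_target_ok = (
        isinstance(target, (list, tuple))
        and isinstance(loss, MultiLoss)
        and len(loss) == len(target)
    )
    return single_target_ok or multi_target_ok


if __name__ == "__main__":
    print(loss_fits_target("target", NormalDistributionLoss()))  # True
    print(loss_fits_target("target", MAE()))  # False: one target, one loss, still rejected
```

Because both branches share one assertion message, a single target with a non-distribution loss is reported as a "number of targets" mismatch.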

In all cases I get these warnings:

/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (NormalDistributionLoss). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
[Parallel(n_jobs=-1)]: Using backend MultiprocessingBackend with 2 concurrent workers.
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (SMAPE). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (MAE). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (RMSE). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (MAPE). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (MASE). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
  warnings.warn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:152: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=False)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=False)`.
  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

Expected behavior

Passing loss=pytorch_forecasting.metrics.MAE() should either work or raise a more meaningful error.

How To Reproduce

The source code is included in the "Bug Report" section above.

Environment

Python: 3.7 (Kaggle doesn't support Python 3.8)
Etna: 1.11.1
pytorch_forecasting: 0.10.1

Additional context

No response

Checklist

  • [X] Bug appears at the latest library version

mirik123 · Aug 12 '22 10:08

Hi, @mirik123!

Thank you for your bug report.

We have no explicit type checking here, but the signature has a type annotation: DeepARModel(..., loss: Optional["DistributionLoss"], ...). So only DistributionLoss-based losses are supported.

I guess the type annotation is the only guarantee we can offer for now.

We'll try to design a general mechanism of explicit type checking for all models and signatures. If you have any ideas, you're welcome to share them.
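One possible shape for such a mechanism is a decorator that checks explicitly passed constructor arguments against their annotations. The sketch below is hypothetical: `check_arg_types`, `DeepARLikeModel` and the loss classes are stand-ins invented for illustration, not etna or pytorch_forecasting code. It only checks plain classes and `Optional`/`Union` of classes, skips anything else (such as `List[float]`), and needs Python 3.8+ for `typing.get_origin`/`typing.get_args`.

```python
import functools
import inspect
import typing
from typing import Optional, Tuple, Union


def _allowed_classes(hint) -> Optional[Tuple[type, ...]]:
    """Return classes usable with isinstance(), or None if the hint is not checkable."""
    if typing.get_origin(hint) is None and isinstance(hint, type):
        return (hint,)  # a plain class such as int or DistributionLoss
    if typing.get_origin(hint) is Union:
        args = typing.get_args(hint)  # Optional[X] is Union[X, NoneType]
        if all(typing.get_origin(a) is None and isinstance(a, type) for a in args):
            return args
    return None  # e.g. List[float]: skipped rather than half-checked


def check_arg_types(func):
    """Raise TypeError when an explicitly passed argument contradicts its annotation."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        hints = typing.get_type_hints(func)  # also resolves string forward references
        bound = signature.bind(*args, **kwargs)  # defaults are not bound, so not checked
        for name, value in bound.arguments.items():
            allowed = _allowed_classes(hints[name]) if name in hints else None
            if allowed is not None and not isinstance(value, allowed):
                expected = " or ".join(cls.__name__ for cls in allowed)
                raise TypeError(
                    f"{func.__qualname__}: argument {name!r} expects {expected}, "
                    f"got {type(value).__name__}"
                )
        return func(*args, **kwargs)

    return wrapper


class DistributionLoss:  # hypothetical stand-in for the pytorch_forecasting base class
    pass


class NormalDistributionLoss(DistributionLoss):
    pass


class MAE:  # hypothetical stand-in for a point loss
    pass


class DeepARLikeModel:
    """Hypothetical model with a DeepARModel-like signature (not etna's class)."""

    @check_arg_types
    def __init__(self, max_epochs: int = 10, loss: Optional["DistributionLoss"] = None):
        self.max_epochs = max_epochs
        self.loss = loss
```

With this in place, `DeepARLikeModel(loss=MAE())` fails at construction time with a message naming `loss` and `DistributionLoss`, instead of failing deep inside `from_dataset`.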

martins0n · Aug 15 '22 08:08

Hi @martins0n

Thank you for the update. What about all these warnings produced by pipeline.backtest? Are they OK? (pytorch_lightning: v1.6.5)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of 'nn.Module' and is already saved during checkpointing. It is recommended to ignore them using 'self.save_hyperparameters(ignore=['logging_metrics'])'.

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:152: LightningDeprecationWarning: Setting 'Trainer(checkpoint_callback=False)' is deprecated in v1.5 and will be removed in v1.7. Please consider using 'Trainer(enable_checkpointing=False)'.

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:133: UserWarning: You defined a 'validation_step' but have no 'val_dataloader'. Skipping val loop.

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:486: RuntimeWarning: ReduceLROnPlateau conditioned on metric val_loss which is not available but strict is set to 'False'. Skipping learning rate update.

mirik123 · Aug 19 '22 15:08

@mirik123 The warnings are OK; they don't affect the final results. But we're going to resolve all of them where possible.
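If the noise is unwanted in the meantime, the harmless ones can be filtered with Python's standard `warnings` module. A minimal sketch: the helper name is hypothetical, and the regex prefixes are copied from the warnings quoted in this thread. Filters installed in the main process may not reach joblib worker processes (n_jobs=-1), so it may need to run inside each worker too.

```python
import warnings

# Message prefixes taken from the warnings quoted above; matched from the start of the message.
HARMLESS_WARNING_PATTERNS = [
    r"Torchmetrics v0\.9 introduced a new argument class property",
    r"Attribute '.*' is an instance of `nn\.Module`",
    r"You defined a `validation_step` but have no `val_dataloader`",
]


def silence_harmless_warnings() -> None:
    """Ignore the UserWarnings listed above; everything else still shows."""
    for pattern in HARMLESS_WARNING_PATTERNS:
        warnings.filterwarnings("ignore", message=pattern, category=UserWarning)
```

The checkpoint_callback deprecation warning is deliberately left visible, since it points at a real incompatibility with newer pytorch_lightning versions.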

checkpoint_callback=False could cause errors with pytorch_lightning>=1.7.0, since that argument is removed in v1.7. We have fixed this in #866, and it should no longer raise warnings in the next release.
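For code that builds its own pytorch_lightning Trainer, the rename can be handled by picking the keyword from the installed version. A minimal sketch under stated assumptions: the helper names are hypothetical, the 1.5 cut-off comes from the deprecation warning quoted above, and this is not necessarily how #866 implements the fix.

```python
import re
from typing import Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.6.5' or '1.7.0rc1' into (1, 6, 5) / (1, 7, 0) using leading digits only."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def checkpointing_off_kwargs(pl_version: str) -> Dict[str, bool]:
    """Trainer kwargs that disable checkpointing for the given pytorch_lightning version."""
    if parse_version(pl_version) >= (1, 5):
        return {"enable_checkpointing": False}  # name used since v1.5
    return {"checkpoint_callback": False}  # older name, removed in v1.7
```

Usage would look like `pytorch_lightning.Trainer(max_epochs=5, **checkpointing_off_kwargs(pytorch_lightning.__version__))`.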

martins0n · Aug 23 '22 09:08