
[BUG] Make `DeepARModel` deterministic

Status: Open · opened by Mr-Geekman

🐛 Bug Report

Currently, `DeepARModel` predicts different values on each forecast. This does not seem to be a general problem with the neural network models: `TFTModel` predicts the same values every time.

We can probably make the forecast deterministic if we fix the `n_samples` parameter used for prediction in the `__init__` of `DeepARModel`. For deterministic behaviour, `n_samples` should be set to `None`.
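To illustrate why `n_samples` matters, here is a minimal, self-contained sketch using only the standard library. `ToyGaussianForecaster` is a hypothetical stand-in, not the etna or pytorch_forecasting implementation: it assumes each horizon step follows a Gaussian and that the point forecast is the average of `n_samples` random draws, while `n_samples=None` returns the distribution mean directly.

```python
import random
import statistics
from typing import List, Optional


class ToyGaussianForecaster:
    """Hypothetical stand-in for a probabilistic model such as DeepAR.

    Each horizon step is modelled as a Gaussian with a fixed mean and std.
    This is only an illustration, not the etna/pytorch_forecasting code.
    """

    def __init__(self, means: List[float], stds: List[float], n_samples: Optional[int] = 100):
        self.means = means
        self.stds = stds
        # None -> use the distribution mean (deterministic)
        # int  -> average over n_samples random draws (stochastic)
        self.n_samples = n_samples

    def forecast(self) -> List[float]:
        if self.n_samples is None:
            # No sampling, so every call returns exactly the same values.
            return list(self.means)
        point = []
        for mu, sigma in zip(self.means, self.stds):
            # Fresh random draws on every call make the result vary.
            draws = [random.gauss(mu, sigma) for _ in range(self.n_samples)]
            point.append(statistics.mean(draws))
        return point


if __name__ == "__main__":
    sampled = ToyGaussianForecaster([1.0, 2.0], [0.5, 0.5], n_samples=100)
    print(sampled.forecast() == sampled.forecast())  # differs between calls

    exact = ToyGaussianForecaster([1.0, 2.0], [0.5, 0.5], n_samples=None)
    print(exact.forecast() == exact.forecast())  # True
```

Seeding the random generator makes the sampled variant reproducible, but only the `n_samples=None` path is deterministic without any seed.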

Expected behavior

`DeepARModel` predicts the same values on every forecast.

After the fix, you should:

  1. Initialize `DeepARModel` in `test_inference` as deterministic.
  2. Remove the torch random initialization from `test_inference._test_forecast_out_sample_prefix`.

How To Reproduce

```python
# pandas.util.testing is deprecated; pandas.testing is the public location
from pandas.testing import assert_frame_equal
from pytorch_forecasting.data import GroupNormalizer

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.models.nn import DeepARModel
from etna.pipeline import Pipeline
from etna.transforms import PytorchForecastingTransform


def main():
    # load data: 3 synthetic AR segments, 100 daily points each
    df = generate_ar_df(periods=100, n_segments=3, start_time="2020-01-01", freq="D", random_seed=0)
    ts = TSDataset(df=TSDataset.to_dataset(df), freq="D")

    # fit pipeline
    model = DeepARModel(max_epochs=5, learning_rate=[0.01])
    pf_transform = PytorchForecastingTransform(
        max_encoder_length=5,
        max_prediction_length=5,
        time_varying_known_reals=["time_idx"],
        time_varying_unknown_reals=["target"],
        target_normalizer=GroupNormalizer(groups=["segment"]),
    )

    pipeline = Pipeline(model=model, horizon=5, transforms=[pf_transform])
    pipeline.fit(ts)

    # forecast twice with the same fitted pipeline
    ts_forecast_1 = pipeline.forecast()
    ts_forecast_2 = pipeline.forecast()
    # raises AssertionError if the two forecasts differ
    assert_frame_equal(ts_forecast_1.to_pandas(), ts_forecast_2.to_pandas())


if __name__ == "__main__":
    main()
```

Environment

No response

Additional context

No response

Checklist

  • [X] Bug appears at the latest library version

Mr-Geekman, Jul 15 '22 09:07