Significant gap in forecasts from v2.0.1 to v3.0.0 when using DistributionLoss

Status: Open. Antoine-Schwartz opened this issue 7 months ago • 2 comments

What happened + What you expected to happen

I suspect a bug, or at least an unintentional change introduced in the 3.0.0 release, that affects forecasts produced with a DistributionLoss. Depending on the nature of the data, the 3.0.0 forecasts are sometimes better and sometimes worse.

By contrast, results are identical across the two versions when using a point loss (BasePointLoss, e.g. MAE).

Versions / Dependencies

neuralforecast==2.0.1 or neuralforecast==3.0.0
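When comparing runs from two environments, it may help to record the exact versions of neuralforecast and of the libraries it builds on, since a difference in a dependency could also shift results. The sketch below uses the standard library's `importlib.metadata`; the helper name `installed_versions` and the package list are assumptions, not part of the original report.

```python
# Sketch: report the installed versions of neuralforecast and a few of its
# dependencies. The package list is an assumption; extend it as needed.
from importlib.metadata import PackageNotFoundError, version

DEFAULT_PACKAGES = ("neuralforecast", "torch", "pytorch-lightning", "pandas", "numpy")


def installed_versions(packages=DEFAULT_PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Reading package metadata instead of importing each package avoids the cost of importing torch just to print a version string.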

Reproduction script

```python
import itertools

from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS, TFT
from neuralforecast.losses.pytorch import MAE, DistributionLoss
from neuralforecast.utils import AirPassengersPanel

Y_df = AirPassengersPanel

# Look up model classes by name instead of calling eval() on strings.
MODELS = {"TFT": TFT, "NHITS": NHITS}
LOSSES = [MAE(), DistributionLoss("Normal", level=[]), DistributionLoss("StudentT", level=[])]

nf = NeuralForecast(
    models=[
        MODELS[name](
            h=12,
            input_size=48,
            max_steps=100,
            scaler_type="standard",
            loss=loss,
            # e.g. "TFT-Normal" for a DistributionLoss, "TFT-MAE" for the point loss
            alias=f"{name}-{loss.distribution}" if isinstance(loss, DistributionLoss) else f"{name}-MAE",
            enable_model_summary=False,
            enable_checkpointing=False,
            enable_progress_bar=False,
            logger=False,
        )
        for name, loss in itertools.product(MODELS, LOSSES)
    ],
    freq="M",
)
nf.fit(Y_df)

forecast = nf.predict(Y_df)

# Sum each model's forecasts; numeric_only=True skips the unique_id and ds columns.
print(forecast.sum(numeric_only=True))
```
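Column sums can hide per-point differences that partly cancel out. One way to localize the gap, assuming the forecasts are saved in each environment with something like `forecast.to_csv("forecast_2.0.1.csv", index=False)` (the file names are hypothetical, and `unique_id` and `ds` are assumed to be regular columns; call `reset_index()` first if `unique_id` is the index), is a small standard-library comparison. The helpers `read_forecasts` and `max_abs_diff` are illustrative, not part of neuralforecast.

```python
# Sketch: compare per-point forecasts from two saved runs (e.g. one per
# neuralforecast version). Assumes unique_id and ds are CSV columns.
import csv


def read_forecasts(lines):
    """Map (unique_id, ds) -> {model_column: value}.

    `lines` may be an open CSV file or any iterable of CSV text lines.
    Empty or missing cells are skipped rather than parsed.
    """
    table = {}
    for row in csv.DictReader(lines):
        key = (row.pop("unique_id"), row.pop("ds"))
        table[key] = {col: float(val) for col, val in row.items() if val}
    return table


def max_abs_diff(old, new):
    """Largest absolute difference per model column, over the
    (unique_id, ds) keys and columns present in both runs."""
    diffs = {}
    for key in old.keys() & new.keys():
        for col, old_val in old[key].items():
            if col in new[key]:
                gap = abs(new[key][col] - old_val)
                diffs[col] = max(diffs.get(col, 0.0), gap)
    return diffs
```

Passing the two opened CSV files (opened with `newline=""`) through `read_forecasts` and then `max_abs_diff` gives the largest per-point gap for each model column. If the point loss really is unaffected, the `-MAE` columns should report 0, while the distribution-loss columns show where the versions diverge.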

Output with v2.0.1:

```
TFT-MAE                16336.152344
TFT-Normal             14885.810547
TFT-Normal-median      14902.537109
TFT-StudentT           16206.841797
TFT-StudentT-median    16213.893555
NHITS-MAE              16241.298828
NHITS-Normal           16364.853516
NHITS-Normal-median    16366.496094
NHITS-StudentT         16165.771484
NHITS-StudentT-median  16168.845703
```

Output with v3.0.0:

```
TFT-MAE                16336.152344
TFT-Normal             16099.189453
TFT-Normal-median      16103.021484
TFT-StudentT           16125.381836
TFT-StudentT-median    16130.875977
NHITS-MAE              16241.298828
NHITS-Normal           16229.751953
NHITS-Normal-median    16231.222656
NHITS-StudentT         16322.078125
NHITS-StudentT-median  16323.810547
```
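To put the two outputs side by side, the relative change of each column sum can be computed directly from the numbers reported above. This is only arithmetic on the posted sums; the dictionary and function names are illustrative.

```python
# Column sums reported in this issue for neuralforecast 2.0.1 and 3.0.0,
# copied from the two outputs above.
SUMS_V2 = {
    "TFT-MAE": 16336.152344,
    "TFT-Normal": 14885.810547,
    "TFT-Normal-median": 14902.537109,
    "TFT-StudentT": 16206.841797,
    "TFT-StudentT-median": 16213.893555,
    "NHITS-MAE": 16241.298828,
    "NHITS-Normal": 16364.853516,
    "NHITS-Normal-median": 16366.496094,
    "NHITS-StudentT": 16165.771484,
    "NHITS-StudentT-median": 16168.845703,
}
SUMS_V3 = {
    "TFT-MAE": 16336.152344,
    "TFT-Normal": 16099.189453,
    "TFT-Normal-median": 16103.021484,
    "TFT-StudentT": 16125.381836,
    "TFT-StudentT-median": 16130.875977,
    "NHITS-MAE": 16241.298828,
    "NHITS-Normal": 16229.751953,
    "NHITS-Normal-median": 16231.222656,
    "NHITS-StudentT": 16322.078125,
    "NHITS-StudentT-median": 16323.810547,
}


def relative_change(old, new):
    """Percent change of each column sum from `old` to `new`."""
    return {col: 100.0 * (new[col] - old[col]) / old[col] for col in old}


if __name__ == "__main__":
    for col, pct in relative_change(SUMS_V2, SUMS_V3).items():
        print(f"{col:22s} {pct:+7.2f}%")
```

On these numbers both MAE columns change by exactly 0%, TFT-Normal and its median shift by roughly +8%, and the remaining distribution-loss columns move by less than 1% in either direction.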

Issue Severity

Medium: It is a significant difficulty but I can work around it.

Antoine-Schwartz, Apr 17 '25 16:04

Thanks for creating the issue, @Antoine-Schwartz. I'll have a look into it.

elephaint, Apr 18 '25 06:04

Hello @elephaint, this issue still exists in 3.0.2. Do you have any clues? Don't hesitate to reach out if you need a crash tester :)

Antoine-Schwartz, Jun 20 '25 14:06