neuralforecast
Significant gap in forecasts from v2.0.1 to v3.0.0 when using DistributionLoss
What happened + What you expected to happen
I suspect a bug, or at least an unintended change introduced in the 3.0.0 release, that affects forecasts produced with a DistributionLoss. Depending on the data, the change sometimes improves accuracy and sometimes degrades it.
By contrast, results are identical across the two versions when using a point loss (a BasePointLoss such as MAE).
Versions / Dependencies
neuralforecast==2.0.1 or neuralforecast==3.0.0
Reproduction script
import pandas as pd
import numpy as np
import itertools
from neuralforecast import NeuralForecast
from neuralforecast.models import DeepAR, TFT, NHITS
from neuralforecast.losses.pytorch import MAE, DistributionLoss
from neuralforecast.utils import AirPassengersPanel
Y_df = AirPassengersPanel
# One model per (architecture, loss) pair: 2 architectures x 3 losses = 6 models
nf = NeuralForecast(
    models=[
        eval(model)(
            h=12,
            input_size=48,
            max_steps=100,
            scaler_type="standard",
            loss=loss,
            alias=f"{model}-{loss.distribution}" if str(loss) == "DistributionLoss()" else f"{model}-MAE",
            enable_model_summary=False,
            enable_checkpointing=False,
            enable_progress_bar=False,
            logger=False
        )
        for model, loss in itertools.product(
            ["TFT", "NHITS"], [MAE(), DistributionLoss("Normal", level=[]), DistributionLoss("StudentT", level=[])]
        )
    ],
    freq="M"
)
nf.fit(Y_df)
forecast = nf.predict(Y_df)
# Sum each forecast column; numeric_only=True skips the unique_id and ds columns
print(forecast.sum(numeric_only=True))
Output with v2.0.1:
TFT-MAE 16336.152344
TFT-Normal 14885.810547
TFT-Normal-median 14902.537109
TFT-StudentT 16206.841797
TFT-StudentT-median 16213.893555
NHITS-MAE 16241.298828
NHITS-Normal 16364.853516
NHITS-Normal-median 16366.496094
NHITS-StudentT 16165.771484
NHITS-StudentT-median 16168.845703
Output with v3.0.0:
TFT-MAE 16336.152344
TFT-Normal 16099.189453
TFT-Normal-median 16103.021484
TFT-StudentT 16125.381836
TFT-StudentT-median 16130.875977
NHITS-MAE 16241.298828
NHITS-Normal 16229.751953
NHITS-Normal-median 16231.222656
NHITS-StudentT 16322.078125
NHITS-StudentT-median 16323.810547
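To put a size on the gap, the column sums reported above can be compared directly. The sketch below is plain Python (it does not need neuralforecast installed) and computes the relative change of each column from v2.0.1 to v3.0.0. The helper names `relative_changes` and `split_by_loss` are hypothetical, and comparing column sums is only a coarse proxy for comparing the forecasts point by point.

```python
# Column sums from forecast.sum(numeric_only=True), copied from the outputs above.
V2_0_1 = {
    "TFT-MAE": 16336.152344,
    "TFT-Normal": 14885.810547,
    "TFT-Normal-median": 14902.537109,
    "TFT-StudentT": 16206.841797,
    "TFT-StudentT-median": 16213.893555,
    "NHITS-MAE": 16241.298828,
    "NHITS-Normal": 16364.853516,
    "NHITS-Normal-median": 16366.496094,
    "NHITS-StudentT": 16165.771484,
    "NHITS-StudentT-median": 16168.845703,
}
V3_0_0 = {
    "TFT-MAE": 16336.152344,
    "TFT-Normal": 16099.189453,
    "TFT-Normal-median": 16103.021484,
    "TFT-StudentT": 16125.381836,
    "TFT-StudentT-median": 16130.875977,
    "NHITS-MAE": 16241.298828,
    "NHITS-Normal": 16229.751953,
    "NHITS-Normal-median": 16231.222656,
    "NHITS-StudentT": 16322.078125,
    "NHITS-StudentT-median": 16323.810547,
}


def relative_changes(old, new):
    """Hypothetical helper: return {column: (new - old) / old} for columns in both runs."""
    return {col: (new[col] - old[col]) / old[col] for col in old if col in new}


def split_by_loss(changes, tol=1e-9):
    """Hypothetical helper: split columns into (unchanged, changed) by |relative change| vs tol."""
    unchanged = sorted(col for col, r in changes.items() if abs(r) <= tol)
    changed = sorted(col for col, r in changes.items() if abs(r) > tol)
    return unchanged, changed


if __name__ == "__main__":
    changes = relative_changes(V2_0_1, V3_0_0)
    for col, r in changes.items():
        print(f"{col:<24} {r:+.2%}")
    unchanged, changed = split_by_loss(changes)
    print("unchanged:", unchanged)
    print("changed:", changed)
```

Applied to the numbers above, only the two MAE columns come out unchanged. Every DistributionLoss column (both the mean and the `-median` output) moves: by roughly +8% for the two TFT-Normal columns and by less than 1% in magnitude for the others.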
Issue Severity
Medium: It is a significant difficulty but I can work around it.
Thanks for creating the issue, @Antoine-Schwartz. I'll look into it.
Hello @elephaint, this issue still exists in 3.0.2. Do you have any leads? Don't hesitate to reach out if you need a tester :)