pytorch-forecasting
Temporal Fusion Transformer gives constant prediction
- PyTorch-Forecasting version: 1.0.0
- PyTorch version: 2.0.1
- Python version: 3.10
- Operating System: Windows
Expected behavior
My expectation is to get a prediction that is not completely constant and that captures some of the volatility in the curve.
Actual behavior
No matter how much I tune the hyperparameters, I get a purely constant/linear prediction.
Code to reproduce the problem
# Imports assumed by the snippet (not shown in the original report)
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

max_encoder_length = 100
max_prediction_length = 100
context_length = max_encoder_length
prediction_length = max_prediction_length

training = TimeSeriesDataSet(
    train_data,
    time_idx="new_time",
    target="units",
    group_ids=["marketplace"],
    static_categoricals=["marketplace"],
    time_varying_unknown_reals=["units"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
    allow_missing_timesteps=True,
    add_relative_time_idx=True,
    target_normalizer=GroupNormalizer(groups=["marketplace"], transformation="softplus"),
)

batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)

net = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.1,
    hidden_size=27,
    attention_head_size=2,
    dropout=0.18,
    hidden_continuous_size=16,
    lstm_layers=3,
    output_size=7,  # 7 quantiles of QuantileLoss
    loss=QuantileLoss(),
    log_interval=1,
    log_val_interval=1,
)

# trainer and val_dataloader are defined elsewhere (not shown in the report)
trainer.fit(
    net,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
The result after prediction is:
for idx in range(5):
    best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx)
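For context, the snippet above relies on a trainer, a validation dataloader, and a raw prediction object that are not shown in the report. A minimal sketch of how those are typically created with pytorch-forecasting 1.0.0 and Lightning 2.x (names such as val_data and the trainer settings are assumptions, not taken from the report):

import lightning.pytorch as pl
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer

# Validation set built from the training dataset definition (val_data is assumed)
validation = TimeSeriesDataSet.from_dataset(training, val_data, predict=True, stop_randomization=True)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

# A basic Lightning trainer
trainer = pl.Trainer(max_epochs=30, gradient_clip_val=0.1)

# After fitting, reload the best checkpoint and get raw predictions for plotting
best_model_path = trainer.checkpoint_callback.best_model_path
best_model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)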
Can someone help me understand what is wrong with my code? I have also tried it on different datasets and still had the same issue.
I have encountered the same problem and have no idea how to capture some volatility in the prediction. I have searched the web for possible solutions but have not found any.
Me too; hoping the original creator or anyone else has some input on this. I am not sure whether the problem is with the TFT algorithm or with the implementation. Have you used Darts or other packages for this?
Same issue here. I did notice that after including multiple categories in my data I got more variation in the prediction.
Can you please elaborate on what you mean by including more categories? Are you talking about static categories? You mean more predictors? I can try it and update it here.
I have tried various combinations of hyperparameters and the implementation from https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91, but the forecasts still remain flat. If I find a solution, I will let you know. Could you do the same? What about posting this problem on Stack Overflow?
I have done the same and tried everything, but the forecast is still flat. If I find a solution, I will let you know (post it here or on Stack Overflow).
Can you please elaborate on what you mean by including more categories? Are you talking about static categories? You mean more predictors? I can try it and update it here.
Yes, static categories. Curious what you end up finding. Looking forward to your update.
I tried adding other predictors, adding the group variables as static variables, and playing with some hyperparameters, but still nothing: constant prediction.
We have the same problem. We started with pytorch-forecasting on a previous version and then moved on to a custom implementation based on the Playtika approach. Training runs fine, and in-sample evaluation looks good. We have a very similar setup with PyTorch and Lightning > 2.0. Most probably it is an issue with PyTorch and/or Lightning in v2.
@meteoDaniel, thanks, this is very interesting. Do you recommend downgrading the PyTorch Lightning version (which seems like a suboptimal solution)? Have you tried Darts?
@ManieTadayon we are actually trying this; for debugging, it is a good way to narrow down the problem. We will keep you updated as soon as we have some helpful insights.
Sorry I'm not close to the project anymore! Good luck
Hi,
Is it really completely constant, as you said? Looking at your picture, I would say that the prediction is not constant.
Best regards
It learns absolutely no volatility compared to similar models, regardless of the features added or hyperparameter tuning. The change in the forecast is minimal (hence almost constant).
How does the prediction accuracy compare to similar models?
Since it does not capture any volatility, the accuracy or MAPE is not good, but regardless of the error, we need to capture the volatility in the data.
I am having the same problem here:
It seems to predict the average all right, but no volatility at all. I am experimenting with the encoder length, but no dice so far.
Very strange, since the validation loss does actually go down...
Will let everyone know if I find a solution...
Hi everyone, I think I fixed this issue. The problem is coming from the decoder side: the relative time index alone is not enough to learn any volatility. Mine now has an acceptable prediction, at least for one particular example.
May I ask what other features you fed in? Something like day of week, month, etc., or more actual data?
Sure, yes: the day of week, month, year, etc. help a lot in capturing the volatility.
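For anyone wanting to try the same thing, a hedged sketch of what adding such calendar features as known covariates could look like (it assumes train_data has a datetime column named "date"; the remaining arguments mirror the snippet at the top of the thread):

# Sketch: derive calendar features and register them as known covariates
# (assumes train_data has a datetime column named "date"; adjust to your data)
train_data["day_of_week"] = train_data["date"].dt.dayofweek.astype(str)
train_data["month"] = train_data["date"].dt.month.astype(str)

training = TimeSeriesDataSet(
    train_data,
    time_idx="new_time",
    target="units",
    group_ids=["marketplace"],
    static_categoricals=["marketplace"],
    time_varying_known_categoricals=["day_of_week", "month"],  # known in the future, so the decoder sees them
    time_varying_unknown_reals=["units"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
    allow_missing_timesteps=True,
    add_relative_time_idx=True,
    target_normalizer=GroupNormalizer(groups=["marketplace"], transformation="softplus"),
)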
Has there been any more insight in the meantime? Because I don't think adding features is the solution.
It is true that those add volatility, but it would still come entirely from those features, like in a regression, and not at all from the encoded input sequence. For example, when you have a weekly pattern and shift the input data by three days, all predictions would be off by those three days, because the model did not pick up anything from the encoder side and just repeated the statically learned weekday pattern. At least that's my experience, even on the simplest toy datasets.
Maybe I have the wrong expectations, but I would have thought this kind of model could dynamically pick up patterns from the encoded sequence, like an ETS or ARIMA model does. So I still wonder whether this is a problem just in this implementation.
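One way to check this is the toy experiment described above: feed the model a pure weekly pattern with no calendar covariates and see whether the forecast follows the phase of the encoder window or just repeats a learned static pattern. A minimal sketch of such a toy series, assuming a single group and no external features (all names here are illustrative):

import numpy as np
import pandas as pd

# Toy dataset: a pure weekly sine pattern plus a little noise, one series, no calendar features.
# If the model really uses the encoder, forecasts made from windows shifted by a few days
# should shift their phase accordingly; a fixed-phase or flat forecast suggests the decoder
# is only repeating a statically learned pattern.
n_steps = 500
time_idx = np.arange(n_steps)
values = 10 + 5 * np.sin(2 * np.pi * time_idx / 7) + np.random.normal(0, 0.2, n_steps)

toy_data = pd.DataFrame({"time_idx": time_idx, "value": values, "series": "toy"})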
For me it was the normalizer. Softplus should be pretty good for most scenarios; however, I couldn't use it since my numbers were too small at times. Check out GroupNormalizer in the docs. I had luck with relu, but your mileage may vary.
@valentinfrlch for clarity, are you saying that disabling GroupNormalizer is what did the trick for you?
@valentinfrlch, no. By default, TFT, NHiTS, DeepAR, etc. normalize the data sequence by sequence; you can configure it so that it is done time series by time series instead.
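For reference, a hedged sketch of the two normalization options in pytorch-forecasting (column names follow the snippet earlier in the thread): EncoderNormalizer rescales each encoder sequence on the fly, while GroupNormalizer computes one scale per time series (group), and its transformation argument (e.g. "softplus" or "relu") controls how values are transformed before scaling.

from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer

# Option 1: normalize sequence by sequence (each encoder window is scaled individually)
per_sequence = EncoderNormalizer()

# Option 2: normalize time series by time series (one scale per group),
# with a transformation such as "relu" instead of "softplus"
per_series = GroupNormalizer(groups=["marketplace"], transformation="relu")

# Either can be passed to TimeSeriesDataSet via target_normalizer=...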
@meteoDaniel Any update on your findings? I do not think it is the package.
Hi,
Has anyone found the root cause of this and how to fix it? I got the same issue with the flat prediction. I have been trying to fix it for a while now.
To be honest, I got some good predictions out of TFT. Try adding static variables, temporal variables, etc. However, it still looks very suboptimal. For example, if you use NHiTS, you will get a good prediction without specifying many external covariates.
Thank you, @manitadayon
Hi guys, I switched to Darts since it's clear that pytorch-forecasting and TFT are not a good combination. I spent 500 hours on hyperparameter tuning and tried the strangest combinations, but whatever I tried, it stays flat. I found this magnificent article, but you have to build it yourself: https://www.playtika-blog.com/playtika-ai/multi-horizon-forecasting-using-temporal-fusion-transformers-a-comprehensive-overview-part-2/ So I switched to Darts, and it worked instantly with the most simple training: https://unit8co.github.io/darts/examples/13-TFT-examples.html
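For anyone going the same route, a minimal sketch of what a Darts TFTModel setup can look like (not taken from the linked notebook; the dataframe, column names, split, and parameter values are assumptions):

import pandas as pd
from darts import TimeSeries
from darts.models import TFTModel

# Build a Darts TimeSeries from a dataframe with a datetime column "date" and target "units" (assumed names)
series = TimeSeries.from_dataframe(df, time_col="date", value_cols="units")
train, val = series[:-100], series[-100:]

# add_relative_index lets the model run without explicit future covariates
model = TFTModel(
    input_chunk_length=100,
    output_chunk_length=100,
    add_relative_index=True,
    n_epochs=30,
)
model.fit(train)
forecast = model.predict(n=100)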
@erwinvink thank you. However, I am confused: did you end up using Darts or tft-pytorch? The article shows the tft-pytorch package, and then you mentioned using Darts again. Did you use the steps outlined there only for data preprocessing and then Darts for training? May I ask why you didn't use that package itself for the whole thing?