
Temporal Fusion Transformer gives constant prediction

Open manitadayon opened this issue 1 year ago • 40 comments

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.0.1
  • Python version: 3.10
  • Operating System: Windows

Expected behavior

My expectation is to get a prediction that is not completely constant and can capture some volatility in the curve.

Actual behavior

No matter how much I tune the hyperparameters, I get a purely constant/linear prediction.

Code to reproduce the problem

# imports needed to run this snippet
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

max_encoder_length = 100
max_prediction_length = 100

context_length = max_encoder_length
prediction_length = max_prediction_length

training = TimeSeriesDataSet(
    train_data,
    time_idx="new_time",
    target="units",
    group_ids=["marketplace"],
    time_varying_unknown_reals=["units"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
    allow_missing_timesteps=True,
    add_relative_time_idx=True,
    static_categoricals=["marketplace"],
    target_normalizer=GroupNormalizer(groups=["marketplace"], transformation="softplus"),
)

batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)

net = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.1,
    hidden_size=27,
    attention_head_size=2,
    dropout=0.18,
    hidden_continuous_size=16,
    log_interval=1,
    log_val_interval=1,
    output_size=7,  # seven quantiles, matching the default QuantileLoss
    loss=QuantileLoss(),
    lstm_layers=3,
)

trainer.fit(
    net,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
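
(`trainer` and `val_dataloader` are defined earlier in my script and omitted above; a rough sketch of those pieces, with illustrative settings:)

import lightning.pytorch as pl

# build the validation set from the training dataset definition
validation = TimeSeriesDataSet.from_dataset(
    training, train_data, predict=True, stop_randomization=True
)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

# a basic Lightning trainer with gradient clipping; the epoch count is illustrative
trainer = pl.Trainer(max_epochs=30, gradient_clip_val=0.1)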

The result after prediction is:


for idx in range(5):
    best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx)
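
(`best_model` and `raw_predictions` above come from something like this — a sketch, assuming the best checkpoint is restored from the trainer:)

best_model_path = trainer.checkpoint_callback.best_model_path
best_model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)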

[image: predicted_actual_TFT]

Can someone help me understand what is wrong with my code? I have also tried it on different datasets and still had the same issue.

manitadayon avatar Jun 17 '23 06:06 manitadayon

I have encountered the same problem and have no idea how to capture some volatility in the prediction. I have searched the web for possible solutions but have not found any.

wkkg avatar Jun 29 '23 17:06 wkkg

Me too; hoping the original creator or anyone else has input on this. I am not sure if the problem is with the TFT algorithm or with the implementation. Have you used darts or other packages for this?

manitadayon avatar Jun 29 '23 21:06 manitadayon

Same issue here. I did notice that after including multiple categories in my data, I got more variation in the prediction.

meetri avatar Jun 30 '23 00:06 meetri

Can you please elaborate on what you mean by including more categories? Are you talking about static categories? You mean more predictors? I can try it and update it here.

manitadayon avatar Jun 30 '23 01:06 manitadayon

I have tried various combinations of hyperparameters and the implementation from https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91, but the forecasts still remain flat. If I find a solution, I will let you know. Could you do the same? What about posting this problem on Stack Overflow?

wkkg avatar Jun 30 '23 06:06 wkkg

I have done the same and tried everything, but the forecast is still flat. If I find a solution, I will let you know (post it here or on Stack Overflow).

manitadayon avatar Jun 30 '23 07:06 manitadayon

> Can you please elaborate on what you mean by including more categories? Are you talking about static categories? You mean more predictors? I can try it and update it here.

Yes. Static categories. Curious what you end up finding. Looking forward to your update

meetri avatar Jun 30 '23 16:06 meetri

I tried adding other predictors, adding group variables as static variables, and playing with some hyperparameters, but still nothing: constant prediction.

manitadayon avatar Jul 01 '23 03:07 manitadayon

We have the same problem. We started with pytorch-forecasting on a previous version and kept going with a custom implementation based on the Playtika approach. Training runs fine, and in-sample evaluation looks good. We have a very similar setup with pytorch and lightning > 2.0. Most probably it is an issue with pytorch and/or lightning in v2.

meteoDaniel avatar Jul 05 '23 15:07 meteoDaniel

@meteoDaniel, thanks, this is very interesting. Do you recommend downgrading the version of PyTorch Lightning (which seems like a suboptimal solution)? Have you tried darts?

ManieTadayon avatar Jul 05 '23 15:07 ManieTadayon

@ManieTadayon we are actually trying this; for debugging, it is a good way to isolate the problem. We will keep you updated as soon as we have some helpful insights.

meteoDaniel avatar Jul 05 '23 15:07 meteoDaniel

Sorry I'm not close to the project anymore! Good luck

JakeF-Bitweave avatar Jul 05 '23 18:07 JakeF-Bitweave

> My expectation is to get a prediction that is not completely constant and can capture some volatility in the curve.

Hi,

Is it really completely constant, as you said? Looking at your picture, I would say the prediction is not constant.

Best regards

grosestq avatar Jul 05 '23 19:07 grosestq

It learns absolutely no volatility compared to similar models, regardless of features added or hyperparameter tuning. The change in the forecast is minimal (hence almost constant).

ManieTadayon avatar Jul 05 '23 19:07 ManieTadayon

> It learns absolutely no volatility compared to similar models, regardless of features added or hyperparameter tuning. The change in the forecast is minimal (hence almost constant).

How does the prediction accuracy compare to similar models?

grosestq avatar Jul 05 '23 19:07 grosestq

Since it does not capture any volatility, the accuracy or MAPE is not good either; but regardless of the error, we need to capture the volatility in the data.

manitadayon avatar Jul 07 '23 02:07 manitadayon

I am having the same problem here: [image: prediction plot] It seems to predict the average all right, but no volatility at all. I am experimenting with the encoder length, but no dice so far.

Very strange, since the validation loss does actually go down: [image: validation loss curve]

Will let everyone know if I find a solution...

valentinfrlch avatar Jul 15 '23 11:07 valentinfrlch

Hi everyone, I think I fixed this issue. The problem is coming from the decoder side, and relative_idx alone is not enough to learn any volatility. Mine now produces an acceptable prediction, at least for one particular example.

manitadayon avatar Jul 18 '23 03:07 manitadayon

May I ask what other features you fed? Something like day of week, month etc. or more actual data?

valentinfrlch avatar Jul 18 '23 04:07 valentinfrlch

Sure, yes: the day of week, month, year, etc. help a lot in capturing the volatility.
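
A minimal sketch of what that can look like when building the dataset (the "date" column and feature names here are assumptions; derive them from your own timestamps):

# assumed: train_data has a datetime column "date" to derive calendar features from
train_data["day_of_week"] = train_data["date"].dt.dayofweek.astype(str)
train_data["month"] = train_data["date"].dt.month.astype(str)

training = TimeSeriesDataSet(
    train_data,
    time_idx="new_time",
    target="units",
    group_ids=["marketplace"],
    time_varying_unknown_reals=["units"],
    # calendar features are known in the future, so the decoder can actually use them
    time_varying_known_categoricals=["day_of_week", "month"],
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
    allow_missing_timesteps=True,
    add_relative_time_idx=True,
    static_categoricals=["marketplace"],
    target_normalizer=GroupNormalizer(groups=["marketplace"], transformation="softplus"),
)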

manitadayon avatar Jul 21 '23 15:07 manitadayon

Has there been any more insight in the meantime? Because I don't think adding features is the solution.

It is true that those add volatility, but it would still entirely come from those features, like in a regression, not at all from the encoded input sequence. For example, when you have a weekly pattern and shift the input data by three days, all predictions would be off by those three days, because the model did not pick up anything from the encoder side and just repeated the statically learned weekday pattern. At least that's my experience, even on the simplest toy datasets.

Maybe I have wrong expectations, but I would have thought this kind of model can dynamically pick up patterns from the encoded sequence, like an ETS or ARIMA model does. So I still wonder if this is a problem just in this implementation.

HHarald99 avatar Sep 04 '23 17:09 HHarald99

For me it was the normalizer. Softplus should be pretty good for most scenarios; I couldn't use it, however, since my numbers were too small at times. Check out GroupNormalizer in the docs. I had luck with relu, but your mileage may vary.
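
A sketch of that change (the semantics are my reading of the docs: with transformation="relu", denormalized predictions are clamped at zero instead of being squashed through softplus):

from pytorch_forecasting.data import GroupNormalizer

# same dataset definition as before; only the target_normalizer changes
target_normalizer = GroupNormalizer(groups=["marketplace"], transformation="relu")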

valentinfrlch avatar Sep 04 '23 18:09 valentinfrlch

@valentinfrlch for clarity, are you saying that disabling GroupNormalizer is what did the trick for you?

abudis avatar Sep 28 '23 08:09 abudis

@valentinfrlch, no: by default TFT, NHiTS, DeepAR, etc. normalize the data sequence by sequence; you can configure it so that normalization is done time series by time series.
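
A sketch of the two options as I understand them (both normalizers live in pytorch_forecasting.data):

from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer

# per-sequence: each encoder window is scaled by its own statistics
sequence_normalizer = EncoderNormalizer()

# per-series: all samples from the same group share one set of scaling statistics
series_normalizer = GroupNormalizer(groups=["marketplace"])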

manitadayon avatar Sep 29 '23 05:09 manitadayon

@meteoDaniel Any update on your findings? I do not think it is the package.

manitadayon avatar Sep 29 '23 06:09 manitadayon

Hi,

Has anyone found the root cause of this and how to fix it? I got the same issue with the flat prediction. I have been trying to fix it for a while now.

rd1886 avatar Nov 07 '23 15:11 rd1886

To be honest, I got some good predictions out of TFT. Try adding static variables, temporal variables, etc. However, it still looks very suboptimal. For example, if you use NHiTS, you will get a good prediction without specifying many external covariates.
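
For reference, a minimal NHiTS sketch on the same dataset (hyperparameters are illustrative, not tuned; assumes a fresh Trainer and the dataloaders from earlier in the thread):

from pytorch_forecasting import NHiTS

nhits = NHiTS.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=64,
    dropout=0.1,
)
trainer.fit(nhits, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)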

manitadayon avatar Nov 07 '23 16:11 manitadayon

Thank you, @manitadayon

rd1886 avatar Nov 08 '23 18:11 rd1886

Hi guys, I switched to Darts, since it's clear that pytorch-forecasting and TFT are not a good combination. I spent 500 hours on hyperparameter tuning and tried the strangest combinations, but whatever I tried, it stayed flat. I found this magnificent article, but you have to build it yourself: https://www.playtika-blog.com/playtika-ai/multi-horizon-forecasting-using-temporal-fusion-transformers-a-comprehensive-overview-part-2/ So I switched to Darts, and it worked instantly with the most simple training: https://unit8co.github.io/darts/examples/13-TFT-examples.html
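
For anyone else switching, the gist of the Darts version I used (column names are placeholders for your own data; this builds a single series, so with multiple marketplaces you would create one TimeSeries per group):

from darts import TimeSeries
from darts.models import TFTModel

# build a darts TimeSeries from the dataframe; "date" and "units" are placeholder names
series = TimeSeries.from_dataframe(train_data, time_col="date", value_cols="units")

# add_relative_index lets the model run without explicit future covariates
model = TFTModel(input_chunk_length=100, output_chunk_length=100, add_relative_index=True)
model.fit(series)
forecast = model.predict(n=100)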

erwinvink avatar Nov 14 '23 18:11 erwinvink

@erwinvink thank you. However, I am confused: did you end up using darts or tft-pytorch? The article shows the tft-pytorch package, and then you mentioned using darts again. Did you use the steps outlined there only for data preprocessing and then darts for training? May I ask why you did not use that package for the whole thing?

manitadayon avatar Nov 14 '23 19:11 manitadayon