Help Reproducing the TiDE Paper

Open · kevincoakley opened this issue 10 months ago · 6 comments

Hi, I am trying to reproduce the TiDE paper (https://arxiv.org/abs/2304.08424) using Darts, and I am having some issues.

The paper reports MSE: 0.454 and MAE: 0.465 for ETTh1 (horizon 720); however, I am getting MSE: 0.026 and MAE: 0.127. Am I doing something wrong with the data or the evaluation? Any pointers would be appreciated. Thank you.

from darts.models import TiDEModel
from darts.datasets import ETTh1Dataset
from darts.dataprocessing.transformers.scaler import Scaler
from darts.metrics import mae, mse


epochs = 20
lookback = 720
horizon = 720
series = ETTh1Dataset().load()

# hyperparameters from the paper
hidden_size = 256
num_encoder_layers = 2
num_decoder_layers = 2
decoder_output_dim = 8
temporal_decoder_hidden = 128
dropout = .3
use_layer_norm = True
lr = 3.82e-5
rev_in = True
batch_size = 512

pl_trainer_kwargs = {
    "max_epochs": epochs,
}

optimizer_kwargs = {
    "lr": lr,
}

# 6:2:2 split
train, temp = series.split_after(0.6)
val, test = temp.split_after(0.5)

scaler = Scaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

model_tide = TiDEModel(
    input_chunk_length = lookback,
    output_chunk_length = horizon,

    hidden_size = hidden_size,
    num_encoder_layers = num_encoder_layers,
    num_decoder_layers = num_decoder_layers,
    decoder_output_dim = decoder_output_dim,
    temporal_decoder_hidden = temporal_decoder_hidden,
    dropout = dropout,
    use_layer_norm = use_layer_norm,
    use_reversible_instance_norm=rev_in,
    batch_size = batch_size,

    pl_trainer_kwargs = pl_trainer_kwargs,
    optimizer_kwargs = optimizer_kwargs,
    save_checkpoints = True, 
    force_reset = True,
    model_name="tide"
)

model_tide.fit(
    series=train,
    val_series=val,
)

best_model = model_tide.load_from_checkpoint(model_name="tide", best=True)

best_model.save("tide_best_model")

predict = best_model.predict(n=len(test), series=val)

pred_mae = mae(test, predict)
pred_mse = mse(test, predict)

print("mae:", pred_mae)
print("mse:", pred_mse)

forecasts = best_model.historical_forecasts(series=test,
                                            start=0,
                                            forecast_horizon=horizon,
                                            retrain=False)

forecast_mae = mae(test, forecasts)
forecast_mse = mse(test, forecasts)

print("forecast_mae:", forecast_mae)
print("forecast_mse:", forecast_mse)

(kevincoakley, Apr 19 '24 19:04)

Hey! At a high level, I can see a few issues with the replication.

Darts' Scaler defaults to min-max scaling (scikit-learn's MinMaxScaler), whereas the paper uses standardization. Your data split also differs from the paper's.

From the paper: "All models were trained using MSE as the training loss. In all the datasets, the train:validation:test ratio is 7:1:2 as dictated by prior work. Note that all the experiments are performed on standard normalized datasets (using the mean and the standard deviations in the training period) in order to be consistent with prior work (Wu et al., 2021)."
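A minimal NumPy sketch of the preprocessing the quoted passage describes, assuming a contiguous (unshuffled) 7:1:2 split and per-feature statistics taken from the training period only. The data is a synthetic random walk standing in for ETTh1's seven columns, and the helper names are hypothetical; in Darts, the equivalent is Scaler(StandardScaler()) with fit_transform called on the training series and transform on the validation and test series.

```python
import numpy as np


def chronological_split(values, ratios=(0.7, 0.1, 0.2)):
    """Split a (time, features) array into contiguous train/val/test blocks (no shuffling)."""
    n = len(values)
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    train = values[:n_train]
    val = values[n_train:n_train + n_val]
    test = values[n_train + n_val:]
    return train, val, test


def fit_standardizer(train):
    """Per-feature mean and standard deviation, computed on the training period only."""
    return train.mean(axis=0), train.std(axis=0)


def standardize(x, mean, std):
    """Apply training-period statistics to any split (train, val, or test)."""
    return (x - mean) / std


# Synthetic stand-in for ETTh1: 1000 hourly steps, 7 features (NOT the real data).
rng = np.random.default_rng(0)
values = rng.normal(loc=10.0, scale=3.0, size=(1000, 7)).cumsum(axis=0)

train, val, test = chronological_split(values)
mean, std = fit_standardizer(train)
train_z, val_z, test_z = (standardize(part, mean, std) for part in (train, val, test))
```

Only the training block ends up with zero mean and unit variance; the validation and test blocks are transformed with the same statistics, so no information from later periods leaks into the scaling.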

They also use learning rate scheduling in the paper.

From the paper: "We also tune the maximum learningRate which is the input to a cosine decay learning rate schedule."
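For reference, a plain-Python sketch of a single cosine decay from a maximum learning rate down to a floor. The decay length (100 steps) and the zero floor are assumptions for illustration; the quoted passage does not say how many steps the paper decays over.

```python
import math


def cosine_decay_lr(step, max_lr, decay_steps, min_lr=0.0):
    """Single cosine decay from max_lr to min_lr over decay_steps, then held at min_lr."""
    progress = min(step, decay_steps) / decay_steps
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


max_lr = 3.82e-5  # the lr value used in the issue's code
schedule = [cosine_decay_lr(epoch, max_lr, decay_steps=100) for epoch in range(101)]
```

PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR follows the same curve over its T_max steps, so in Darts, passing lr_scheduler_cls=torch.optim.lr_scheduler.CosineAnnealingLR with lr_scheduler_kwargs={"T_max": epochs} is one way to approximate it, assuming the scheduler is stepped once per epoch.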

(alexcolpitts96, Apr 20 '24 17:04)

Hi @alexcolpitts96, thank you for taking the time to reply!

I wasn't sure what to do about the data split: the paper says 7:1:2, but their code uses 6:2:2, unless I am misreading it. I tried both, and it didn't change the results significantly.

https://github.com/google-research/google-research/blob/a3e7b75d49edc68c36487b2188fa834e02c12986/tide/train.py#L95

I made the other changes, and the results were MAE: 0.91 and MSE: 1.39, which are now higher than the reported results. I've included my updated code below. Any other suggestions on what I might be doing wrong? Thank you again!

import torch
from darts.models import TiDEModel
from darts.datasets import ETTh1Dataset
from darts.dataprocessing.transformers.scaler import Scaler
from darts.metrics import mae, mse
from sklearn.preprocessing import StandardScaler

epochs = 100
lookback = 720
horizon = 720
series = ETTh1Dataset().load()

# hyperparameters from the paper
hidden_size = 256
num_encoder_layers = 2
num_decoder_layers = 2
decoder_output_dim = 8
temporal_decoder_hidden = 128
dropout = .3
use_layer_norm = True
lr = 3.82e-5
rev_in = True
batch_size = 512

# 7:1:2 split
train, temp = series.split_after(0.7)
val, test = temp.split_after(1 / 3)  # 1/3 of the remaining 30% is about 10% of the full series

standard_scaler = StandardScaler()

scaler = Scaler(standard_scaler)
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

# PyTorch Lightning trainer
pl_trainer_kwargs = {
    "max_epochs": epochs,
}

optimizer_kwargs = {
    "lr": lr,
}

# learning rate scheduler
lr_scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
lr_scheduler_kwargs = {
    "T_0": 10,
}


model_tide = TiDEModel(
    input_chunk_length = lookback,
    output_chunk_length = horizon,

    hidden_size = hidden_size,
    num_encoder_layers = num_encoder_layers,
    num_decoder_layers = num_decoder_layers,
    decoder_output_dim = decoder_output_dim,
    temporal_decoder_hidden = temporal_decoder_hidden,
    dropout = dropout,
    use_layer_norm = use_layer_norm,
    use_reversible_instance_norm=rev_in,
    batch_size = batch_size,

    pl_trainer_kwargs = pl_trainer_kwargs,
    optimizer_kwargs = optimizer_kwargs,
    lr_scheduler_cls = lr_scheduler_cls,
    lr_scheduler_kwargs = lr_scheduler_kwargs,
    save_checkpoints = True, 
    force_reset = True,
    model_name="tide"
)

model_tide.fit(
    series=train,
    val_series=val,
)

best_model = model_tide.load_from_checkpoint(model_name="tide", best=True)

best_model.save("tide_best_model")

predict = best_model.predict(n=len(test), series=val)

pred_mae = mae(test, predict)
pred_mse = mse(test, predict)

print("mae:", pred_mae)
print("mse:", pred_mse)

###########

forecasts = best_model.historical_forecasts(series=test,
                                            start=0,
                                            forecast_horizon=horizon,
                                            retrain=False)

forecast_mae = mae(test, forecasts)
forecast_mse = mse(test, forecasts)

print("forecast_mae:", forecast_mae)
print("forecast_mse:", forecast_mse)

(kevincoakley, Apr 20 '24 20:04)

I see that they used gradient clipping. You can enable it in Darts by passing it through pl_trainer_kwargs, e.g. gradient_clip_val=0.5 (see https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html).
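What gradient_clip_val does, sketched in NumPy under the assumption that Lightning's default gradient_clip_algorithm="norm" is in effect: the gradients of all parameters are treated as one vector and rescaled if their combined L2 norm exceeds the threshold (mirroring torch.nn.utils.clip_grad_norm_). The toy gradients and the 0.5 threshold are illustrative, not the paper's settings.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    # Same form as PyTorch: coefficient max_norm / (norm + eps), never scaling up.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


# Two toy "parameter" gradients whose combined norm is sqrt(9 + 16 + 144) = 13.
grads = [np.array([3.0, 4.0]), np.array([12.0])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=0.5)
```

In Darts this would just be one more entry in the trainer kwargs, e.g. pl_trainer_kwargs = {"max_epochs": epochs, "gradient_clip_val": 0.5}; gradients whose norm is already below the threshold pass through unchanged.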

I am also not sure you configured the learning rate scheduler identically to the paper. I'm not very familiar with its settings, since I generally don't use cosine annealing.
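One concrete difference, sketched in plain Python: CosineAnnealingWarmRestarts with T_0=10 (and PyTorch's default T_mult=1), stepped once per epoch, decays over 10 epochs and then jumps back to the maximum learning rate, whereas "cosine decay" suggests a single decay over the run. Both functions are hypothetical re-implementations of the closed-form formulas, not the PyTorch classes, and the 100-epoch run length is an assumption taken from the issue's code.

```python
import math


def single_cosine(epoch, max_lr, total_epochs, min_lr=0.0):
    """One cosine decay over the whole run (no restarts)."""
    t = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))


def cosine_warm_restarts(epoch, max_lr, t_0, min_lr=0.0):
    """SGDR-style schedule with fixed period t_0 (T_mult=1): decay, then reset to max_lr."""
    t_cur = epoch % t_0  # epochs since the last restart
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_0))


max_lr, epochs = 3.82e-5, 100
decay = [single_cosine(e, max_lr, epochs) for e in range(epochs)]
restarts = [cosine_warm_restarts(e, max_lr, t_0=10) for e in range(epochs)]
```

Over 100 epochs the warm-restart version returns to max_lr ten times, while the single decay only passes through each learning rate once.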

They also set seeds (np.random.seed(1024), tf.random.set_seed(1024)), but I am not sure whether those were used for the published results, since the script has an override: https://github.com/google-research/google-research/blob/a3e7b75d49edc68c36487b2188fa834e02c12986/tide/train.py#L132
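A minimal sketch of seeding for repeatable runs, covering only Python's random module and NumPy's global generator. For the Darts model itself, TiDEModel accepts a random_state argument (e.g. random_state=1024) that controls weight initialization, and Lightning provides a seed_everything helper; neither is used here so the snippet runs without PyTorch. The helper names are hypothetical.

```python
import random

import numpy as np

SEED = 1024  # the seed value posted in the TiDE repository


def seed_everything(seed):
    """Seed Python's and NumPy's global RNGs (PyTorch would also need torch.manual_seed)."""
    random.seed(seed)
    np.random.seed(seed)


def sample_run(seed):
    """Stand-in for a training run: reseed, then draw a few random numbers."""
    seed_everything(seed)
    return random.random(), np.random.rand(3).tolist()


run_a = sample_run(SEED)
run_b = sample_run(SEED)
```

Even with the same seed value, TensorFlow and PyTorch draw different random numbers, so seeding makes your own runs repeatable but will not reproduce the exact trajectories of the original TensorFlow code.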

(alexcolpitts96, Apr 20 '24 20:04)

Hi @kevincoakley, you should also inverse-transform the historical forecasts and compute the metrics against the untransformed test series.
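Why the scale matters for the reported numbers, in a NumPy sketch with one synthetic feature: with a standard scaler, an error in original units is the scaled error times the training standard deviation, so MSE grows by std² and MAE by std after inverse transforming. The data and the noisy "forecast" are made up; in Darts the inverse step is scaler.inverse_transform(forecasts). With several features, each has its own std, so metrics averaged over features do not convert by a single factor.

```python
import numpy as np


def mae(actual, pred):
    return float(np.mean(np.abs(actual - pred)))


def mse(actual, pred):
    return float(np.mean((actual - pred) ** 2))


rng = np.random.default_rng(1)
# Hypothetical single-feature data in original units (not ETTh1).
train = rng.normal(loc=20.0, scale=5.0, size=500)
test = rng.normal(loc=20.0, scale=5.0, size=100)
mean, std = train.mean(), train.std()  # training-period statistics only

test_scaled = (test - mean) / std
pred_scaled = test_scaled + rng.normal(scale=0.3, size=test.shape)  # stand-in model output
pred = pred_scaled * std + mean  # inverse transform back to original units

scaled_mse, raw_mse = mse(test_scaled, pred_scaled), mse(test, pred)
scaled_mae, raw_mae = mae(test_scaled, pred_scaled), mae(test, pred)
```

The same forecast therefore yields very different MSE/MAE values depending on which scale the metrics are computed on, so a comparison with a published table only makes sense on the scale that table used.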

(dennisbader, Apr 21 '24 16:04)

Hi @dennisbader, thank you for taking the time to answer my question. I'm not clear on what you mean by the untransformed test series; I've been unsure whether I was evaluating the predictions properly.

Do you mean something like this?

# Insert this before setting test = scaler.transform(test)
untransformed_test = test
...
test = scaler.transform(test)
forecasts = best_model.historical_forecasts(series=test, 
                                           start=0, 
                                           forecast_horizon=horizon, 
                                           retrain=False)

forecasts_mae = mae(untransformed_test, scaler.inverse_transform(forecasts))
forecasts_mse = mse(untransformed_test, scaler.inverse_transform(forecasts))

print("forecast_mae:", forecasts_mae)
print("forecast_mse:", forecasts_mse)

(kevincoakley, Apr 21 '24 17:04)

@kevincoakley, yes, that's it :)

(dennisbader, Apr 22 '24 09:04)