
TFT Model resume training from checkpoint not matching continuous training

Open · wehrlik opened this issue 1 year ago · 1 comment

Hi there! Because of a shortage of computing resources, I was forced to submit short jobs. After returning to longer jobs, I realised that the trained models are not exactly the same when trained continuously as when reloaded from a checkpoint. I am using TFTModel.load_from_checkpoint as suggested in https://github.com/unit8co/darts/pull/1689#issuecomment-1500346963. Not getting the same behaviour could cause reproducibility problems whenever training has to be interrupted, and I wonder whether there are other artifacts of this. My question is: is this a bug, or is something missing that would have to be loaded in addition to the checkpoint file, or altered after loading?
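For comparison, here is a minimal plain-PyTorch sketch (not darts or Lightning code) of what a resume has to restore for the continued run to match an uninterrupted one. That means the model weights, the optimizer state (here the Adam moments), and every random number stream: the global RNG that drives dropout and the generator that drives batch shuffling. The tiny network, the data, and the helpers `make_model` and `train_steps` are hypothetical stand-ins. Whether darts checkpoints restore each of these pieces is exactly the open question of this issue, and the sketch does not check that.

```python
import io

import torch
from torch import nn


def make_model() -> nn.Module:
    # Fixed init seed so every "process" starts from identical weights.
    torch.manual_seed(0)
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.1), nn.Linear(8, 1))


def train_steps(model, opt, x, y, n_steps, gen):
    loss_fn = nn.MSELoss()
    for _ in range(n_steps):
        # Batch sampling uses a dedicated generator; dropout uses the global torch RNG.
        idx = torch.randperm(len(x), generator=gen)[:8]
        opt.zero_grad()
        loss = loss_fn(model(x[idx]), y[idx])
        loss.backward()
        opt.step()


data_gen = torch.Generator().manual_seed(1)
x = torch.randn(64, 4, generator=data_gen)
y = torch.randn(64, 1, generator=data_gen)

# Continuous run: 6 steps without interruption.
model_a = make_model()
opt_a = torch.optim.Adam(model_a.parameters(), lr=1e-2)
torch.manual_seed(123)
gen_a = torch.Generator().manual_seed(42)
train_steps(model_a, opt_a, x, y, 6, gen_a)

# Interrupted run: 3 steps, then save everything the resume needs.
model_b = make_model()
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-2)
torch.manual_seed(123)
gen_b = torch.Generator().manual_seed(42)
train_steps(model_b, opt_b, x, y, 3, gen_b)
buffer = io.BytesIO()
torch.save(
    {
        "model": model_b.state_dict(),
        "optimizer": opt_b.state_dict(),
        "torch_rng": torch.get_rng_state(),  # drives dropout masks
        "sampler_rng": gen_b.get_state(),  # drives batch shuffling
    },
    buffer,
)

# "New process": fresh objects (make_model re-seeds the global RNG),
# then restore all saved state before running the remaining 3 steps.
buffer.seek(0)
state = torch.load(buffer)
model_c = make_model()
opt_c = torch.optim.Adam(model_c.parameters(), lr=1e-2)
model_c.load_state_dict(state["model"])
opt_c.load_state_dict(state["optimizer"])
torch.set_rng_state(state["torch_rng"])
gen_c = torch.Generator()
gen_c.set_state(state["sampler_rng"])
train_steps(model_c, opt_c, x, y, 3, gen_c)

same = all(
    torch.equal(p_a, p_c) for p_a, p_c in zip(model_a.parameters(), model_c.parameters())
)
print("resumed run matches continuous run:", same)
```

If the two RNG-restoring lines are dropped and the second half is simply re-seeded, you get a partial resume of the kind that can make the two runs drift apart.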

To reproduce this, I used the air passengers example to illustrate the behaviour with a simpler model than my own. I interrupted training several times by killing the job, the first time at around epoch 33. As the following figure shows, the training loss diverges once training has been interrupted (orange is the continuous run, violet the run that was resumed several times).

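[Figure: training loss of the continuous run (orange) vs. the run resumed several times from checkpoint (violet)]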
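To locate the divergence numerically rather than by eye, you could compare the two per-epoch loss curves with a small helper such as the hypothetical `first_divergence` below. The curves could be, for example, scalars exported from TensorBoard, since the model logs with log_tensorboard=True. If the resume is the only thing that differs between the runs, the returned index should coincide with the first interruption.

```python
import math
from typing import Optional, Sequence


def first_divergence(
    loss_a: Sequence[float],
    loss_b: Sequence[float],
    rel_tol: float = 1e-6,
    abs_tol: float = 1e-9,
) -> Optional[int]:
    """Return the index of the first epoch where the two curves are not close, or None.

    Only the overlapping prefix is compared (zip stops at the shorter curve).
    """
    for epoch, (a, b) in enumerate(zip(loss_a, loss_b)):
        if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
            return epoch
    return None


# Toy values for illustration only (not taken from the runs in this issue).
continuous = [1.00, 0.80, 0.60, 0.50]
resumed = [1.00, 0.80, 0.65, 0.52]
print(first_divergence(continuous, resumed))  # prints 2
```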

```python
import numpy as np
import pandas as pd

from pytorch_lightning import Trainer

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import TFTModel

from darts.datasets import AirPassengersDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression
from darts.models.forecasting.torch_forecasting_model import _get_checkpoint_fname

# Read data
series = AirPassengersDataset().load()

# Convert the monthly number of passengers to the average daily number of passengers per month
series = series / TimeSeries.from_series(series.time_index.days_in_month)
series = series.astype(np.float32)

# Create training and validation sets
training_cutoff = pd.Timestamp("19571201")
train, val = series.split_after(training_cutoff)

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

# Create year, month and integer index covariate series
covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series, attribute="month", one_hot=False)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series.time_index,
        values=np.arange(len(series)),
        columns=["linear_increase"],
    )
)
covariates = covariates.astype(np.float32)

# Transform covariates (the scaler is fitted on the training part only)
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(training_cutoff)
scaler_covs.fit(cov_train)
covariates_transformed = scaler_covs.transform(covariates)

workdir = "load_model_example/"
model_name = "test_model_reload"

# Define model
quantiles = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
input_chunk_length = 24
forecast_horizon = 12


def define_model():
    my_model = TFTModel(
        model_name=model_name,
        input_chunk_length=input_chunk_length,
        output_chunk_length=forecast_horizon,
        hidden_size=64,
        lstm_layers=1,
        num_attention_heads=4,
        dropout=0.1,
        batch_size=16,
        n_epochs=300,
        work_dir=workdir,
        save_checkpoints=True,
        log_tensorboard=True,
        pl_trainer_kwargs={"log_every_n_steps": 5},
        add_relative_index=False,
        add_encoders=None,
        likelihood=QuantileRegression(quantiles=quantiles),
        random_state=42,
    )
    return my_model


# Resume from the latest checkpoint (best=False) if one exists; otherwise start a new model.
# Note: _get_checkpoint_fname is a private darts helper and may change between versions.
try:
    checkpoint_file = _get_checkpoint_fname(workdir, model_name, best=False)
    model = TFTModel.load_from_checkpoint(
        file_name=checkpoint_file, work_dir=workdir, model_name=model_name
    )
except FileNotFoundError:
    model = define_model()

# Build a fresh Lightning Trainer from the model's own trainer settings
trainer_params = model.trainer_params
trainer = Trainer(**trainer_params)

model.fit(train_transformed, future_covariates=covariates_transformed, trainer=trainer, verbose=False)
```

Additional context

In my real-life example, I did not let the job die. Instead, I trained for only 5 epochs at a time, then loaded the model from the checkpoint and reset max_epochs in the trainer (as described in https://github.com/unit8co/darts/issues/1090#issuecomment-1193999342). My training loss curve is shown below, with a very prominent pattern repeating every 5 epochs. So it seems that interrupting training frequently leads to artifacts.

[Figure: training loss of the real-life model, showing a pattern that repeats every 5 epochs]
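Here is one hypothesis consistent with a pattern that repeats every 5 epochs; it has not been confirmed by the maintainers. Suppose some random state (for example, data-loader shuffling or dropout) is re-seeded from the same seed on every resume instead of being restored. Then each 5-epoch chunk would replay the same random sequence. The sketch below uses Python's `random` module as a stand-in for the training RNG. `epoch_orders` and `restart_every` are hypothetical names, not darts parameters.

```python
import random


def epoch_orders(n_epochs, n_samples, seed, restart_every=None):
    """Per-epoch shuffle orders. If restart_every is set, the RNG is re-created
    from `seed` every `restart_every` epochs, mimicking a job that re-seeds on
    each resume instead of restoring its RNG state (hypothetical scenario)."""
    orders = []
    rng = random.Random(seed)
    for epoch in range(n_epochs):
        if restart_every and epoch > 0 and epoch % restart_every == 0:
            rng = random.Random(seed)  # RNG state lost on "resume"
        order = list(range(n_samples))
        rng.shuffle(order)
        orders.append(order)
    return orders


continuous = epoch_orders(10, 16, seed=42)
restarted = epoch_orders(10, 16, seed=42, restart_every=5)

# Before the first restart both runs see identical batch orders...
assert restarted[:5] == continuous[:5]
# ...after it, the restarted run replays epochs 0-4 instead of continuing the sequence.
assert restarted[5:] == restarted[:5]
```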

System

darts 0.30.0, torch 2.3.1, Python 3.10.11

wehrlik · Jul 18 '24 08:07

Hi @wehrlik, and thanks for raising this issue. We're investigating it, but it'll take some time due to low capacity at the moment.

dennisbader · Jul 24 '24 07:07