pytorch-forecasting

[BUG] Temporal Fusion Transformer trained on GPU then loaded on CPU does not propagate map_location into the loss metrics as expected

Open arizzuto opened this issue 1 year ago • 17 comments

Describe the bug

I have trained a TFT on a machine with several GPUs. When I load the model on a smaller CPU-only machine with:

model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location=torch.device('cpu'))

I get this error:

AssertionError: Torch not compiled with CUDA enabled

I have tracked the problem to the loss metric for the model (QuantileLoss in my case) being initialised with a CUDA device rather than the CPU device implied by the map_location keyword argument above.

Here is the full trace:

Traceback (most recent call last):
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/model_evaluation.py", line 45, in <module>
    main(model_path)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/model_evaluation.py", line 22, in main
    model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location='cpu')
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1581, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/core/saving.py", line 100, in _load_from_checkpoint
    return model.to(device)
           ^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 55, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torchmetrics/metric.py", line 908, in _apply
    _dummy_tensor = fn(torch.zeros(1, device=self.device))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/cuda/__init__.py", line 310, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

Is there a way to propagate map_location into the loss metrics for the model?

arizzuto avatar Jan 15 '25 04:01 arizzuto

After some more digging, on lines 62-63 of lightning/pytorch/core/saving.py:

with pl_legacy_patch():
    checkpoint = pl_load(checkpoint_path, map_location=map_location)

With map_location = torch.device('cpu'), the returned checkpoint has:

>>> checkpoint['hyper_parameters']['loss'].device
device(type='cuda', index=0)

which is not the expected device(type='cpu')
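
To check which hyperparameter objects still report a non-CPU device after loading, one option is to walk checkpoint['hyper_parameters'] and read each object's device attribute. The sketch below is duck-typed so that it runs without torch installed. non_cpu_entries is a hypothetical helper name, and the SimpleNamespace objects are stand-ins for real metric objects, not anything from the library.

```python
from types import SimpleNamespace


def non_cpu_entries(hyper_parameters):
    """Return {label: device string} for entries whose .device is not a CPU device."""
    found = {}
    for name, value in hyper_parameters.items():
        if getattr(value, "device", None) is not None:
            # The entry itself carries a device (e.g. a loss metric).
            candidates = [(name, value)]
        elif hasattr(value, "__iter__") and not isinstance(value, (str, bytes, dict)):
            # List-like entries (e.g. logging_metrics): check each element.
            candidates = [(f"{name}[{i}]", item) for i, item in enumerate(value)]
        else:
            candidates = []
        for label, obj in candidates:
            device = getattr(obj, "device", None)
            if device is not None and not str(device).startswith("cpu"):
                found[label] = str(device)
    return found


# Stand-in data (assumption): in practice pass checkpoint["hyper_parameters"].
hparams = {
    "loss": SimpleNamespace(device="cuda:0"),
    "logging_metrics": [SimpleNamespace(device="cuda:0"), SimpleNamespace(device="cpu")],
    "hidden_size": 16,
}
report = non_cpu_entries(hparams)
print(report)
```

An empty result would mean every device-carrying entry already reports a CPU device.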

arizzuto avatar Jan 15 '25 05:01 arizzuto

Has any workaround been found for this?

stefangordon avatar Mar 16 '25 14:03 stefangordon

My current workaround: after training on the GPUs, I load the trained model on rank 0, force everything to CPU, then pickle the model. Loading the pickled model on a machine with CPU-only PyTorch then seems to work fine.

arizzuto avatar Mar 16 '25 22:03 arizzuto

@arizzuto can you share the code for loading on rank 0 and forcing it to CPU? It would help to know whether you are dumping the model object that comes from best_tft itself, or dumping the state_dict and using that to load the model on CPU.

saurabh-sh704 avatar Mar 18 '25 14:03 saurabh-sh704

I've just encountered this issue after trying to update python, torch, pytorch-forecasting, and pytorch-lightning. Previously it worked fine to train on GPU (EC2) and then load the model and predict on CPU (MacBook).

@saurabh-sh704, @stefangordon, or @arizzuto would you like to try these older versions to test whether this bug is related to some subsequent update?

The configuration for which this works:

python = "~=3.10.0"
numpy = "==1.23.5"
pytorch-forecasting = "~=0.10.2"
pytorch-lightning = "~=1.8.0"
torch = "==1.13.1"

The versions which cause the issue:

python = "~=3.11.7"
numpy = "~=1.26.1"
pytorch-forecasting = "^1.0.0"
pytorch-lightning = "^2.0.0"
torch = "~=2.6.0"
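
When comparing environments like the two above, it can help to print the versions that are actually installed rather than the constraints in the project file. This is a minimal sketch using only the standard library. installed_versions is a hypothetical helper, and the package list is an assumption that you should adjust to your own setup.

```python
from importlib import metadata

# Distribution names as published on PyPI (adjust as needed).
PACKAGES = ["torch", "lightning", "pytorch-lightning", "pytorch-forecasting", "torchmetrics", "numpy"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```

Running this on both the training machine and the inference machine gives two lists that can be compared directly.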

mkuiack avatar Apr 17 '25 21:04 mkuiack

Hi @saurabh-sh704, here's the block that does everything; creation of the model and trainer is abstracted away here. Also, I'm not using best_tft: I have a fixed training process with tuned parameters and early stopping set up appropriately for my use case. The idea is to load the model on a machine with a CUDA GPU, move it all onto the CPU, then save it.

        model = create_model(training_dataset, model_configuration.tft_model_params)
        trainer = create_trainer(model_configuration.trainer_params, model_configuration.early_stop_patience, logdir=log_dir)
        trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=val_dataloader)

        if os.getenv("LOCAL_RANK", '0') == '0':

            model_file = os.path.join(log_dir, 'lightning_logs/', 'version_0', 'checkpoints')
            model_file = glob.glob(model_file + '/*.ckpt')[0]
            loaded_model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location=torch.device('cpu'))
            pickle_path = os.path.join(log_dir, 'lightning_logs/', 'version_0', 'model_save.pkl')
            with open(pickle_path, 'wb') as f:
                pickle.dump(loaded_model, f)
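
The snippet above hard-codes version_0 and takes the first glob result, and glob returns matches in arbitrary order. If a run can leave several checkpoints or several version_* folders, one possible variation is to pick the most recently modified .ckpt file. This is a sketch only: newest_checkpoint is a hypothetical helper, and "most recently modified" is an assumption about which checkpoint you want.

```python
from pathlib import Path


def newest_checkpoint(log_dir):
    """Return the most recently modified .ckpt file anywhere under log_dir."""
    candidates = list(Path(log_dir).rglob("*.ckpt"))
    if not candidates:
        raise FileNotFoundError(f"no .ckpt files found under {log_dir}")
    # Compare by modification time; the newest file wins.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

The returned path could replace model_file in the load_from_checkpoint call above.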

arizzuto avatar Apr 22 '25 23:04 arizzuto

Here is another solution.

In this case you initialise the TFT model from the inference feature dataset, then replace the internals of the model with the trained model parameters. The loss and logging_metrics are removed, since these are the components that retain their CUDA dependency no matter what you do.

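    import logging
    from typing import IO

    import pandas as pd
    import torch
    from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

    LOGGER = logging.getLogger(__name__)  # assumed module-level logger

    def _deserialize(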
            model_path: str | IO,
            data: pd.DataFrame
    ) -> TemporalFusionTransformer:
        LOGGER.info("Deserializing model object at path %s", model_path)

        # Load the checkpoint
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        # Ensure the checkpoint contains the state_dict
        if "state_dict" not in checkpoint:
            raise ValueError("Checkpoint does not contain a state_dict.")

        # Extract hyperparameters and dataset parameters
        hyperparameters = checkpoint["hyper_parameters"]

        ## You can either:
        ## Replace CUDA QuantileLoss with a new CPU instance, do this for logging_metrics too
        # if "loss" in hyperparameters and hasattr(hyperparameters["loss"], "quantiles"):
        #     quantiles = hyperparameters["loss"].quantiles
        #     # Create a fresh QuantileLoss instance on CPU
        #     from pytorch_forecasting.metrics import QuantileLoss
        #     hyperparameters["loss"] = QuantileLoss(quantiles=quantiles)

        ## OR
        ## Clear metrics which are not needed for inference and contain GPU dependencies
        metrics_to_remove = ['loss', 'logging_metrics']
        for metric in metrics_to_remove:
            if metric in hyperparameters:
                hyperparameters.pop(metric, None)

        dataset_parameters = checkpoint["dataset_parameters"]

        # If a `TemporalFusionTransformer` was trained while weighing certain data
        # points, it still expects at inference time the weight column to be included
        # in the dataset even though it has no purpose or effect.
        dataset_parameters.pop("weight", None)

        # Initialise the dataset and model with the correct parameters
        dataset = TimeSeriesDataSet.from_parameters(dataset_parameters, data)
        tft = TemporalFusionTransformer.from_dataset(dataset, **hyperparameters)

        # Load the state_dict to the model 
        tft.load_state_dict(checkpoint["state_dict"])

        return tft
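
The second option in the function above (dropping the metric entries) can also be written so that the loaded checkpoint dictionary is left untouched. This is a small sketch of that idea using plain dictionaries. without_metrics is a hypothetical helper, and the sample values are stand-ins, not real hyperparameters.

```python
METRIC_KEYS = ("loss", "logging_metrics")


def without_metrics(hyper_parameters, keys=METRIC_KEYS):
    """Return a copy of the hyperparameters with the metric entries left out."""
    return {k: v for k, v in hyper_parameters.items() if k not in keys}


# Stand-in values (assumption): in practice pass checkpoint["hyper_parameters"].
saved = {"hidden_size": 16, "loss": object(), "logging_metrics": [object()]}
clean = without_metrics(saved)
print(sorted(clean))
```

The cleaned dictionary can then be unpacked into from_dataset in the same way the function above does, while the original entries stay available for inspection.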

mkuiack avatar May 08 '25 11:05 mkuiack

pickle.dump

Thanks for the help. I tried this method to export the pkl file. When I load the file, this error occurs: TypeError: 'TemporalFusionTransformer' object is not subscriptable

mrgreen3325 avatar May 23 '25 07:05 mrgreen3325

Thanks. When I use the method model = TemporalFusionTransformer.load_from_checkpoint('tft.pkl', map_location=torch.device('cpu')), this error occurs: IndexError: index 3 is out of bounds for dimension 1 with size 1. Do you get this too?

mrgreen3325 avatar May 23 '25 07:05 mrgreen3325

Hi @mrgreen3325
tft.pkl is a pickle object, not a checkpoint. Checkpoints are .ckpt files created automatically during training. By default they're in a lightning_logs directory and have names like epoch=0-step=768.ckpt

mkuiack avatar May 23 '25 09:05 mkuiack

Just to be clear @mrgreen3325, once you've created the pickled TFT (tft.pkl file) you load it back in like:

with open('tft.pkl', 'rb') as f:
    model = pickle.load(f)
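
Since the two file types are easy to mix up, a small guard around pickle.load can make the distinction explicit. This is a sketch only: load_pickled_model is a hypothetical helper, the suffix check is an assumption about how the files are named, and a plain dict stands in for the pickled model so the example runs without pytorch-forecasting.

```python
import pickle
import tempfile
from pathlib import Path


def load_pickled_model(path):
    """Unpickle a model file, refusing files that look like Lightning checkpoints."""
    path = Path(path)
    if path.suffix == ".ckpt":
        raise ValueError(
            f"{path.name} looks like a checkpoint; use load_from_checkpoint for it, not pickle"
        )
    with path.open("rb") as f:
        # Only unpickle files you trust: pickle can run arbitrary code on load.
        return pickle.load(f)


# Round trip with a stand-in object instead of a real TemporalFusionTransformer.
with tempfile.TemporaryDirectory() as tmp:
    pkl = Path(tmp) / "tft.pkl"
    with pkl.open("wb") as f:
        pickle.dump({"stand_in": "model"}, f)
    restored = load_pickled_model(pkl)
print(restored)
```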

arizzuto avatar May 25 '25 23:05 arizzuto

Thanks. May I know which function I need to use to open the pkl file? torch.load?

mrgreen3325 avatar May 26 '25 02:05 mrgreen3325

import pickle

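pickle stores the entire TFT object instance (or whatever object you pickle), so no checkpoint is needed to restore it. The packages that define its classes (e.g. pytorch-forecasting) still need to be installed where you unpickle it.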

arizzuto avatar May 26 '25 02:05 arizzuto

Thanks for the reply. How should I set up tsdataset = TimeSeriesDataSet()? Since the data were encoded during training, does the test set need to be encoded in the same way at inference time? Thanks.

mrgreen3325 avatar May 27 '25 01:05 mrgreen3325

This is the same for any TFT, I would recommend looking at some of the great example uses that the pytorch-forecasting devs have created.

E.g. https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html

arizzuto avatar May 27 '25 02:05 arizzuto