[BUG] Temporal Fusion Transformer trained on GPU then loaded on CPU does not propagate map_location into the loss metrics as expected
Describe the bug
I have trained a TFT on a machine with several GPUs. When loading the model on a smaller CPU-only machine with:
```python
model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location=torch.device('cpu'))
```
I get the following error:
```
AssertionError: Torch not compiled with CUDA enabled
```
I have tracked the problem to the loss metric for the model (QuantileLoss in my case) being initialised with a CUDA device rather than the CPU device implied by the `map_location` keyword arg above.
Here is the full trace:
```
Traceback (most recent call last):
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/model_evaluation.py", line 45, in <module>
    main(model_path)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/model_evaluation.py", line 22, in main
    model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location='cpu')
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1581, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/pytorch/core/saving.py", line 100, in _load_from_checkpoint
    return model.to(device)
           ^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 55, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torchmetrics/metric.py", line 908, in _apply
    _dummy_tensor = fn(torch.zeros(1, device=self.device))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronrizzuto/code/spot_price_forecast/forecast_reals_model/venv/lib/python3.12/site-packages/torch/cuda/__init__.py", line 310, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
Is there a way to propagate the `map_location` into the loss metrics for the model?
After some more digging: on lines 62-63 of lightning/pytorch/core/saving.py,
```python
with pl_legacy_patch():
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
```
with `map_location = torch.device('cpu')`, the returned checkpoint has:
```python
>>> checkpoint['hyper_parameters']['loss'].device
device(type='cuda', index=0)
```
which is not the expected `device(type='cpu')`.
Has any workaround been found for this?
My current workaround: after training on the GPUs, I load the trained model on rank 0, forcing everything to CPU, then pickle the model. Loading the pickled model on a machine with CPU-only PyTorch then seems to work fine.
@arizzuto can you share the code for loading on rank 0 and forcing it to CPU? It would be helpful if you could show whether you are dumping the model object resulting from `best_tft` itself, or dumping the `state_dict` and using that to load the model on CPU.
I've just encountered this issue after updating python, torch, pytorch-forecasting, and pytorch-lightning. It previously worked fine to train on GPU (EC2) and then load the model and predict on CPU (MacBook).
@saurabh-sh704, @stefangordon, or @arizzuto would you like to try these older versions to test whether this bug is related to some subsequent update?
The configuration for which this works:
```
python = "~=3.10.0"
numpy = "==1.23.5"
pytorch-forecasting = "~=0.10.2"
pytorch-lightning = "~=1.8.0"
torch = "==1.13.1"
```
The versions which cause the issue:
```
python = "~=3.11.7"
numpy = "~=1.26.1"
pytorch-forecasting = "^1.0.0"
pytorch-lightning = "^2.0.0"
torch = "~=2.6.0"
```
Hi @saurabh-sh704, here's the block that does everything; creation of the model and trainer are abstracted away here. Also, I'm not using `best_tft`; I have a fixed training process with tuned parameters and early stopping set up appropriately for my use case. The idea is to load the model on a machine with CUDA GPUs, put it all onto the CPU, then save it.
```python
import glob
import os
import pickle

import torch
from pytorch_forecasting import TemporalFusionTransformer

# create_model / create_trainer are project-specific helpers defined elsewhere
model = create_model(training_dataset, model_configuration.tft_model_params)
trainer = create_trainer(model_configuration.trainer_params, model_configuration.early_stop_patience, logdir=log_dir)
trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=val_dataloader)

if os.getenv("LOCAL_RANK", '0') == '0':
    model_file = os.path.join(log_dir, 'lightning_logs', 'version_0', 'checkpoints')
    model_file = glob.glob(model_file + '/*.ckpt')[0]
    loaded_model = TemporalFusionTransformer.load_from_checkpoint(model_file, map_location=torch.device('cpu'))
    pickle_path = os.path.join(log_dir, 'lightning_logs', 'version_0', 'model_save.pkl')
    with open(pickle_path, 'wb') as f:
        pickle.dump(loaded_model, f)
```
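To sanity-check the round-trip idea without a trained TFT, here is a minimal sketch using a stand-in object (`DummyModel` is hypothetical; a real run would pickle the `loaded_model` above):

```python
import io
import pickle

class DummyModel:
    """Hypothetical stand-in for the TFT; pickle serialises the whole object."""
    def __init__(self, weights):
        self.weights = weights

buf = io.BytesIO()                     # stands in for the model_save.pkl file
pickle.dump(DummyModel([0.1, 0.2]), buf)
buf.seek(0)
restored = pickle.load(buf)            # what the CPU-only machine would do
print(restored.weights)                # -> [0.1, 0.2]
```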
Here is another solution.
In this case you initialise the TFT model from the inference feature dataset, then replace the internals of the model with the trained model parameters, removing the `loss` and `logging_metrics`, since these are the components that retain their CUDA dependency no matter what you do.
```python
import logging
from typing import IO

import pandas as pd
import torch
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

LOGGER = logging.getLogger(__name__)


def _deserialize(
    model_path: str | IO,
    data: pd.DataFrame
) -> TemporalFusionTransformer:
    LOGGER.info("Deserializing model object at path %s", model_path)
    # Load the checkpoint
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    # Ensure the checkpoint contains the state_dict
    if "state_dict" not in checkpoint:
        raise ValueError("Checkpoint does not contain a state_dict.")
    # Extract hyperparameters and dataset parameters
    hyperparameters = checkpoint["hyper_parameters"]
    ## You can either:
    ## Replace the CUDA QuantileLoss with a fresh CPU instance (do this for logging_metrics too)
    # if "loss" in hyperparameters and hasattr(hyperparameters["loss"], "quantiles"):
    #     quantiles = hyperparameters["loss"].quantiles
    #     # Create a fresh QuantileLoss instance on CPU
    #     from pytorch_forecasting.metrics import QuantileLoss
    #     hyperparameters["loss"] = QuantileLoss(quantiles=quantiles)
    ## OR
    ## Clear metrics which are not needed for inference and contain GPU dependencies
    for metric in ('loss', 'logging_metrics'):
        hyperparameters.pop(metric, None)
    dataset_parameters = checkpoint["dataset_parameters"]
    # If a `TemporalFusionTransformer` was trained while weighing certain data
    # points, it still expects the weight column at inference time even though
    # it has no purpose or effect there.
    dataset_parameters.pop("weight", None)
    # Initialise the dataset and model with the correct parameters
    dataset = TimeSeriesDataSet.from_parameters(dataset_parameters, data)
    tft = TemporalFusionTransformer.from_dataset(dataset, **hyperparameters)
    # Load the state_dict into the model
    tft.load_state_dict(checkpoint["state_dict"])
    return tft
```
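The key step above is just sanitising the hyperparameter dict before rebuilding the model. A stand-alone illustration of that step with a plain dict (no real checkpoint involved; the keys other than `loss` and `logging_metrics` are placeholders):

```python
# Placeholder for checkpoint["hyper_parameters"]; in a real checkpoint,
# "loss" and "logging_metrics" are the entries pinned to a CUDA device.
hyperparameters = {
    "hidden_size": 16,
    "dropout": 0.1,
    "loss": "<QuantileLoss on cuda:0>",
    "logging_metrics": "<ModuleList on cuda:0>",
}

for key in ("loss", "logging_metrics"):
    hyperparameters.pop(key, None)  # pop with a default is safe if the key is absent

print(sorted(hyperparameters))  # -> ['dropout', 'hidden_size']
```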
Thanks for the help. I tried this method to export the .pkl file. When I load the file, the error occurs: `TypeError: 'TemporalFusionTransformer' object is not subscriptable`
Thanks. When I use `model = TemporalFusionTransformer.load_from_checkpoint('tft.pkl', map_location=torch.device('cpu'))`, the error occurs: `IndexError: index 3 is out of bounds for dimension 1 with size 1`. Do you have this too?
Hi @mrgreen3325
`tft.pkl` is a pickle object, not a checkpoint. Checkpoints are `.ckpt` files automatically created during training. By default they're in a `lightning_logs` directory and have names like `epoch=0-step=768.ckpt`.
Just to be clear @mrgreen3325, once you've created the pickled TFT (tft.pkl file) you load it back in like:
```python
with open('tft.pkl', 'rb') as f:
    model = pickle.load(f)
```
Thanks, may I know which function I need to use to open the pkl? torch.load?
Use `import pickle` and `pickle.load`, not torch.load. `pickle` stores the entire TFT object instance (or whatever object you pickle), so no checkpoint-loading code is required; note that the original packages still need to be importable when you unpickle.
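One caveat worth noting (my understanding of general pickle semantics, not anything TFT-specific): pickle stores custom classes by reference (module path plus class name), not by value, so the defining packages still have to be importable on the CPU-only machine when you unpickle. A quick stdlib illustration with a hypothetical stand-in class:

```python
import pickle

class TinyModel:
    """Hypothetical stand-in; a real TFT pickle references pytorch_forecasting."""
    pass

payload = pickle.dumps(TinyModel())
# The class name is embedded as a reference, not as a copy of the class
# definition, so unpickling elsewhere needs the same class importable
# under the same module path.
print(b"TinyModel" in payload)  # -> True
```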
Thanks for the reply. May I know how to set up `tsdataset = TimeSeriesDataSet()`? Since the data are encoded during training, does the test set need to be encoded in the same way during inference? Thanks.
This is the same for any TFT; I would recommend looking at some of the great examples that the pytorch-forecasting devs have created.
E.g. https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/stallion.html