pytorch-forecasting
cannot convert GPU-trained/saved model to CPU for inference
- Python version: Python 3.8.12
- Operating System: Linux AWS
- Package versions: lightning==2.0.1.post0, lightning-cloud==0.5.33, lightning-utilities==0.8.0, numpy==1.24.2, pandas==1.5.3, pyarrow==11.0.0, pytorch-forecasting==1.0.0, pytorch-lightning==2.0.1.post0, pytorch-optimizer==2.5.2, torch==2.0.0, torchmetrics==0.11.4
Expected behavior
device = torch.device('cpu')
model = TemporalFusionTransformer.load_from_checkpoint(best_model_path, map_location=device)
model.device
device(type='cpu')
model.device is correctly reported as 'cpu', but running prediction on the CPU fails:
# calculate mean absolute error on validation set
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
predictions = model.predict(val_dataloader)
(actuals - predictions).abs().mean()
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
43 else:
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py:847, in Trainer._predict_impl(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
844 ckpt_path = self._checkpoint_connector._select_ckpt_path(
845 self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
846 )
--> 847 results = self._run(model, ckpt_path=ckpt_path)
849 assert self.state.stopped
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py:911, in Trainer._run(self, model, ckpt_path)
910 # strategy will configure model and move it to the device
--> 911 self.strategy.setup(self)
913 # hook
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/single_device.py:73, in SingleDeviceStrategy.setup(self, trainer)
72 def setup(self, trainer: pl.Trainer) -> None:
---> 73 self.model_to_device()
74 super().setup(trainer)
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/single_device.py:70, in SingleDeviceStrategy.model_to_device(self)
69 assert self.model is not None, "self.model must be set before self.model.to()"
---> 70 self.model.to(self.root_device)
File /usr/local/lib/python3.8/site-packages/lightning/fabric/utilities/device_dtype_mixin.py:54, in _DeviceDtypeModuleMixin.to(self, *args, **kwargs)
53 self.__update_properties(device=device, dtype=dtype)
---> 54 return super().to(*args, **kwargs)
File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)
File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
796 for module in self.children():
--> 797 module._apply(fn)
799 def compute_should_use_set_data(tensor, tensor_applied):
File /usr/local/lib/python3.8/site-packages/torchmetrics/metric.py:659, in Metric._apply(self, fn)
657 # make sure to update the device attribute
658 # if the dummy tensor moves device by fn function we should also update the attribute
--> 659 self._device = fn(torch.zeros(1, device=self.device)).device
661 # Additional apply to forward cache and computed attributes (may be nested)
File /usr/local/lib/python3.8/site-packages/torch/cuda/__init__.py:247, in _lazy_init()
246 os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247 torch._C._cuda_init()
248 # Some of the queued calls may reentrantly call _lazy_init();
249 # we need to just return without initializing in that case.
250 # However, we must not let any *other* threads in!
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[122], line 5
2 actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
3 # actuals = actuals.to(device)
----> 5 predictions = model.predict(val_dataloader)
6 (actuals - predictions).abs().mean()
File /usr/local/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py:1423, in BaseModel.predict(self, data, mode, return_index, return_decoder_lengths, batch_size, num_workers, fast_dev_run, return_x, return_y, mode_kwargs, trainer_kwargs, write_interval, output_dir, **kwargs)
1421 logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
1422 trainer = Trainer(fast_dev_run=fast_dev_run, **trainer_kwargs)
-> 1423 trainer.predict(self, dataloader)
1424 logging.getLogger("lightning").setLevel(log_level_lighting)
1425 logging.getLogger("pytorch_lightning").setLevel(log_level_pytorch_lightning)
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py:805, in Trainer.predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
803 model = _maybe_unwrap_optimized(model)
804 self.strategy._lightning_module = model
--> 805 return call._call_and_handle_interrupt(
806 self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
807 )
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py:68, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
66 for logger in trainer.loggers:
67 logger.finalize("failed")
---> 68 trainer._teardown()
69 # teardown might access the stage so we reset it after
70 trainer.state.stage = None
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py:958, in Trainer._teardown(self)
955 def _teardown(self) -> None:
956 """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
957 Callback; those are handled by :meth:`_call_teardown_hook`."""
--> 958 self.strategy.teardown()
959 loop = self._active_loop
960 # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
File /usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py:475, in Strategy.teardown(self)
473 if self.lightning_module is not None:
474 log.debug(f"{self.__class__.__name__}: moving model to CPU")
--> 475 self.lightning_module.cpu()
476 self.precision_plugin.teardown()
477 assert self.accelerator is not None
File /usr/local/lib/python3.8/site-packages/lightning/fabric/utilities/device_dtype_mixin.py:78, in _DeviceDtypeModuleMixin.cpu(self)
76 """See :meth:`torch.nn.Module.cpu`."""
77 self.__update_properties(device=torch.device("cpu"))
---> 78 return super().cpu()
File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:954, in Module.cpu(self)
945 def cpu(self: T) -> T:
946 r"""Moves all model parameters and buffers to the CPU.
947
948 .. note::
(...)
952 Module: self
953 """
--> 954 return self._apply(lambda t: t.cpu())
File /usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
795 def _apply(self, fn):
796 for module in self.children():
--> 797 module._apply(fn)
799 def compute_should_use_set_data(tensor, tensor_applied):
800 if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
801 # If the new tensor has compatible tensor type as the existing tensor,
802 # the current behavior is to change the tensor in-place using `.data =`,
(...)
807 # global flag to let the user control whether they want the future
808 # behavior of overwriting the existing tensor or not.
File /usr/local/lib/python3.8/site-packages/torchmetrics/metric.py:659, in Metric._apply(self, fn)
653 raise TypeError(
654 "Expected metric state to be either a Tensor" f"or a list of Tensor, but encountered {current_val}"
655 )
657 # make sure to update the device attribute
658 # if the dummy tensor moves device by fn function we should also update the attribute
--> 659 self._device = fn(torch.zeros(1, device=self.device)).device
661 # Additional apply to forward cache and computed attributes (may be nested)
662 if self._computed is not None:
File /usr/local/lib/python3.8/site-packages/torch/cuda/__init__.py:247, in _lazy_init()
245 if 'CUDA_MODULE_LOADING' not in os.environ:
246 os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247 torch._C._cuda_init()
248 # Some of the queued calls may reentrantly call _lazy_init();
249 # we need to just return without initializing in that case.
250 # However, we must not let any *other* threads in!
251 _tls.is_initializing = True
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
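For completeness: BaseModel.predict accepts trainer_kwargs that are passed to the internally created Trainer (visible in the traceback above), but explicitly requesting the CPU accelerator does not avoid the crash, because the accelerator is already CPU ("GPU available: False") and the failure comes from the metric's cached cuda:0 device being touched while the model is moved. A hedged illustration of that call:
# sketch only - this does not fix the issue on an affected checkpoint, since the
# crash is triggered by the cached cuda:0 device on the metrics, not by the
# accelerator choice
predictions = model.predict(val_dataloader, trainer_kwargs=dict(accelerator="cpu"))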
Having a similar problem with TemporalFusionTransformer and QuantileLoss. I can see that QuantileLoss._device is set to cuda:0 even though the TFT model was mapped to CPU using TemporalFusionTransformer.load_from_checkpoint(best_model_path, map_location='cpu'), as in the OP's post.
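For reference, this is a quick way to see the mismatch after loading (a minimal sketch; it assumes the model exposes loss and logging_metrics as attributes, as pytorch-forecasting's BaseModel does):
import torch
from pytorch_forecasting import TemporalFusionTransformer

model = TemporalFusionTransformer.load_from_checkpoint(best_model_path, map_location=torch.device('cpu'))
print(model.device)        # device(type='cpu') - the module itself was moved
print(model.loss._device)  # on an affected checkpoint this still shows cuda:0
for metric in model.logging_metrics:
    print(metric._device)  # same for the logging metrics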
Does this work?
model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
predictions = model.predict(test).cpu()
No. I'm not sure why, but in the meantime not even TemporalFusionTransformer.load_from_checkpoint(best_model_path) works; it might be because I upgraded to PyTorch Lightning 2.0.2. It now throws an error when calling _load_from_checkpoint:
AssertionError: Torch not compiled with CUDA enabled
I am able to hack around this by adding the following between lines 89 and 90:
# reset the stale device cached on the loss and the logging metrics
storage.loss._device = "cpu"
for metric in storage.logging_metrics:
    metric._device = "cpu"
So manually patching the _device fields containing "cuda:0" to "cpu" in the model loaded from the checkpoint resolves the problem. But I think loading from a checkpoint should work regardless of whether the checkpoint was created on a GPU.
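For anyone who would rather not edit the installed package, roughly the same fix can be applied to the already loaded model by resetting the cached _device on every torchmetrics Metric it contains. This is only a sketch of the workaround above, not an official API; it assumes QuantileLoss and the logging metrics are torchmetrics Metric subclasses reachable via model.modules():
import torch
from torchmetrics import Metric

def move_metrics_to_cpu(model):
    # walk all submodules and reset the device cached on each Metric,
    # so that Metric._apply no longer tries to create tensors on cuda:0
    for module in model.modules():
        if isinstance(module, Metric):
            module._device = torch.device("cpu")
    return model

model = move_metrics_to_cpu(model)
predictions = model.predict(val_dataloader)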
It might be worth mentioning that there are these two warnings in the log when trying to load a model from a checkpoint:
UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
Would ignoring these two attributes resolve the error we are seeing, or is this unrelated?
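For context, those warnings come from Lightning's save_hyperparameters capturing nn.Module constructor arguments. In a plain LightningModule the advice in the warning would look roughly like the sketch below (purely illustrative, not how pytorch-forecasting builds the TFT, and I don't know whether it addresses the device issue):
import lightning.pytorch as pl
import torch.nn as nn

class ToyModel(pl.LightningModule):
    def __init__(self, loss: nn.Module, hidden_size: int = 16):
        super().__init__()
        # exclude the nn.Module argument from the saved hyperparameters, as the
        # warning suggests; its weights are still stored in the checkpoint's state_dict
        self.save_hyperparameters(ignore=["loss"])
        self.loss = loss
        self.layer = nn.Linear(hidden_size, 1)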
Are there any updates on this?
> I am able to hack around this by adding the following between lines 89 and 90:
This worked for me too. Thanks @jurgispods