Serialization: `torch.load(..., map_location="cpu")` doesn't map all tensors in `Laplace` instance
Hi,

I came across an issue when serializing a `BaseLaplace`-derived object with tensors mapped to GPU and then reading it back in and mapping to CPU only. The same goes for a data structure containing such an object. See below for details on the use case.

When using `torch.load(..., map_location="cpu")` in
```python
import dill
import torch
from laplace import Laplace

# fit on the GPU ...
lap = Laplace(
    model=model.to("cuda"),
    likelihood="regression",
    ...
)
lap.fit(train_dl)
torch.save(lap, "lap_gpu.pt", pickle_module=dill)

# ... then load with everything mapped to CPU
lap = torch.load("lap_gpu.pt", map_location="cpu")
```
one assumes that all tensors are mapped to CPU. This is true for the weight tensors in `lap.model`, for instance. However, some tensor-valued properties of `Laplace`, such as `prior_precision_diag`, are computed each time they are accessed, and some of them use `self._device` for that. The latter is defined once in the `BaseLaplace` constructor
```python
self._device = next(model.parameters()).device
```
to be the same device as the one `model` lives on. `torch.load(..., map_location="cpu")` won't affect `self._device`, and thus for all properties that use `self._device`, an attempt to compute them will mix GPU and CPU tensors and fail.
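To make the failure mode concrete, here is a minimal sketch outside of `laplace` (assuming a machine with a CUDA GPU; `Holder` is just a hypothetical stand-in that caches the device the way `BaseLaplace` does):

```python
import torch
import torch.nn as nn

class Holder:
    """Caches the device once at construction time, like BaseLaplace."""
    def __init__(self, model: nn.Module):
        self.model = model
        self._device = next(model.parameters()).device  # plain torch.device attribute

holder = Holder(nn.Linear(2, 1).to("cuda"))
torch.save(holder, "holder.pt")

# map_location remaps the tensor storages, but the pickled torch.device
# attribute is restored unchanged (newer torch may also need weights_only=False)
holder_cpu = torch.load("holder.pt", map_location="cpu")
print(next(holder_cpu.model.parameters()).device)  # cpu
print(holder_cpu._device)                          # cuda:0  (stale)
```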
One solution is to make `_device` a property:
```python
@property
def _device(self):
    return next(self.model.parameters()).device
```
I have a branch where this works. Before I open a PR, I'd like to know whether you think this is useful and whether there are other implications that I haven't run into so far (the test suite passes, of course). At least in terms of performance, this should be fine:
```
>>> %timeit next(lap.model.parameters()).device
1.87 µs ± 2.73 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```
Here are the combinations of `hessian_structure` and `subset_of_weights` that are affected (`load_ok=no`) or not (`load_ok=yes`):
| hessian_structure | subset_of_weights | err | load_ok |
|---|---|---|---|
| diag | last_layer | `ValueError: Jacobians need to be on the same device as Laplace.` | no |
| kron | last_layer | | yes |
| full | last_layer | | yes |
| diag | all | `ValueError: Jacobians need to be on the same device as Laplace.` | no |
| lowrank | all | `RuntimeError: No HIP GPUs are available` | no |
| kron | all | | yes |
| full | all | `RuntimeError: No HIP GPUs are available` | no |
| full | subnetwork | `RuntimeError: No HIP GPUs are available` | no |
With the proposed fix, all of these work. I can provide more details and code to reproduce the above if needed.
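To give a rough idea, a reproduction for one failing combination (`diag` + `all`) could look something like the sketch below (this assumes a CUDA machine for the save step; in practice the CPU-only load would happen on a different machine):

```python
import dill
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

# tiny synthetic regression problem, just to have something to fit
model = nn.Sequential(nn.Linear(2, 10), nn.Tanh(), nn.Linear(10, 1)).to("cuda")
X = torch.randn(64, 2, device="cuda")
y = torch.randn(64, 1, device="cuda")
train_dl = DataLoader(TensorDataset(X, y), batch_size=16)

lap = Laplace(model, "regression",
              subset_of_weights="all", hessian_structure="diag")
lap.fit(train_dl)
torch.save(lap, "lap_gpu.pt", pickle_module=dill)

# on a CPU-only machine:
lap_cpu = torch.load("lap_gpu.pt", map_location="cpu")
lap_cpu(torch.randn(4, 2))  # without the fix: the ValueError from the table above
```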
The use case is serializing a fitted `Laplace` instance, including `self.model`, to disk and avoiding the usual torch `load_state_dict` gymnastics, sketched below (I'm aware of #148, but for small enough models and many test calculations, serializing everything to one blob and loading it back in is more convenient). I may then want to run predictions later on other machines (in a cluster) by loading a fitted `Laplace` from disk, where CPU execution is fast enough. Disclaimer: I'm aware that serializing complex objects is not the best practice because it has lots of potential pitfalls, and that it is not recommended for long-term data storage.
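For context, the `load_state_dict` gymnastics I mean is the usual plain-torch round trip of re-creating objects with a matching spec before loading weights; a minimal sketch for the wrapped model alone:

```python
import torch
import torch.nn as nn

def make_model() -> nn.Module:
    # must be kept in sync by hand with whatever was actually trained
    return nn.Sequential(nn.Linear(2, 10), nn.Tanh(), nn.Linear(10, 1))

model = make_model()
# ... training happens here ...
torch.save(model.state_dict(), "model_state.pt")

# later / on another machine: re-create the architecture, then load the weights
model2 = make_model()
model2.load_state_dict(torch.load("model_state.pt", map_location="cpu"))
```

The same pattern would apply to the `Laplace` object itself with the `state_dict` support from #148: it has to be re-instantiated with a matching spec before loading its state.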
Still, I think it would be nice if `Laplace` instances behaved as transparently under `torch.load()` as `torch.nn.Module`s do.
Thanks.
Thanks, Steve. It seems like a neat addition to #148 and I don't see any downside to this. Feel free to open a pull request!
Anyway, I'm just curious about the use cases for this. If I understand correctly, your proposal will mainly save some effort in matching the `Laplace` specification (i.e., the correct `hessian_structure`, `subset_of_weights`, etc.) between different scripts, e.g., if one has two different long-running processes. Then, using this serialization approach, one only needs to change the spec above in one process's script.

Are there any other use cases / advantages?
Thanks for the feedback. I'll prepare a PR then.
The proposed fix is orthogonal to any chosen workflow, but since you asked, I'm happy to elaborate on the reasoning behind using `torch.save()` + `torch.load()`.

There are two ways in plain torch to save/load trained models, each with its own pros and cons: the `state_dict` way, or the pickle way (`torch.save()` + `torch.load()` of whole objects). #148 implements the `state_dict` way for `Laplace`.
Both ways exist for good reasons. In my view, the `state_dict` way is optimal for long-term storage and/or sharing of trained models. This comes at the cost of having to manually store metadata, such as the constructor arguments for `model` and `Laplace`. Lightning checkpoints automate this away very nicely for Modules / models. There is also ONNX, but I haven't used it in production so far. The pickle way is, as I see it, a convenient and quick way to dump a trained model (and indeed a `Laplace` instance containing one) to disk when iterating on code, running experiments, or debugging, keeping in mind that it is more fragile (e.g. things won't load if you refactor code and rename modules between save and load).
The proposed fix increases torch compatibility of the pickle way for the specific case of using `map_location`: train on a machine with one device (e.g. a GPU), load on another machine without such a device:
```python
# machine with GPU
torch.save(lap, "lap_gpu.pt")

# machine without GPU
lap = torch.load("lap_gpu.pt", map_location="cpu")
```
This works for torch Tensors and Modules, but not for `Laplace` instances. The proposed fix makes this possible by making `lap._device` respect `map_location`.
For the other cases (save and load on the same machine, or on a machine with matching hardware, e.g. both machines have a GPU, which doesn't even need to be the same model), things already work:
```python
# machine with GPU
torch.save(lap, "lap_gpu.pt")

# other machine with GPU
lap = torch.load("lap_gpu.pt")
```
I hope this answers your question.