
Serialization: `torch.load(..., map_location="cpu")` doesn't map all tensors in `Laplace` instance

Open • elcorto opened this issue 11 months ago • 1 comment

Hi

I came across an issue when serializing a BaseLaplace-derived object whose tensors live on the GPU and then loading it back with all tensors mapped to the CPU. The same happens for a data structure that contains such an object. See below for details on the use case.

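When using torch.load(..., map_location="cpu") as in

```python
import dill
import torch
from laplace import Laplace

# model and train_dl are defined elsewhere
lap = Laplace(
    model=model.to("cuda"),
    likelihood="regression",
    # ... (further arguments elided)
)
lap.fit(train_dl)
torch.save(lap, "lap_gpu.pt", pickle_module=dill)

# use the same pickle module for loading
lap = torch.load("lap_gpu.pt", map_location="cpu", pickle_module=dill)
```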

one assumes that all tensors are mapped to the CPU. This is true for the weight tensors in lap.model, for instance. However, some tensor-valued properties of Laplace, such as prior_precision_diag, are recomputed on every access, and some of them use self._device to do so. That attribute is set once, in the BaseLaplace constructor,

```python
self._device = next(model.parameters()).device
```

so it records the device that model lives on at construction time. Because self._device is a plain torch.device object and not a tensor, torch.load(..., map_location="cpu") does not remap it. Every property that uses self._device therefore tries to mix CPU tensors with GPU tensors or, on a machine without a GPU, to allocate on a device that does not exist. Either way, the computation fails.
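A minimal, CPU-only sketch of the mechanism, assuming a recent PyTorch: a torch.device("cuda", 0) object can be constructed without a GPU, so a dict holding a tensor plus a recorded device (a hypothetical stand-in for a Laplace instance and its _device) is enough to show that map_location is applied to tensor storages only, while a pickled device comes back unchanged.

```python
import io

import torch

# Hypothetical stand-in for a Laplace instance: some tensor data plus a device
# that was recorded once at construction time (like BaseLaplace._device).
state = {
    "weights": torch.zeros(3),
    "device": torch.device("cuda", 0),  # creating a device object needs no GPU
}

buf = io.BytesIO()
torch.save(state, buf)
buf.seek(0)

# map_location is applied to every tensor storage while unpickling ...
loaded = torch.load(buf, map_location="cpu", weights_only=False)

print(loaded["weights"].device)
# ... but the pickled torch.device object is restored as it was saved.
print(loaded["device"])
```

Any code that later builds tensors with device=loaded["device"] would still target the GPU, which is the situation the properties in Laplace run into.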

One solution is to make _device a property that always looks up the model's current device:

```python
@property
def _device(self):
    return next(self.model.parameters()).device
```
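To illustrate the difference, here is a sketch with two hypothetical toy classes (not the actual BaseLaplace code): one caches the device in its constructor, the other looks it up from the model on every access. Swapping in a model created on PyTorch's "meta" device (which needs no hardware) stands in for lap.model having been remapped by map_location.

```python
import torch
from torch import nn


class CachedDevice:
    """Hypothetical stand-in: device recorded once, as in the current constructor."""

    def __init__(self, model: nn.Module):
        self.model = model
        self._device = next(model.parameters()).device


class PropertyDevice:
    """Hypothetical stand-in: device looked up from the model on each access."""

    def __init__(self, model: nn.Module):
        self.model = model

    @property
    def _device(self) -> torch.device:
        return next(self.model.parameters()).device


cached = CachedDevice(nn.Linear(2, 2))
live = PropertyDevice(nn.Linear(2, 2))

# Simulate the model ending up on a different device than at construction.
cached.model = nn.Linear(2, 2, device="meta")
live.model = nn.Linear(2, 2, device="meta")

print(cached._device)  # still the construction-time device (stale)
print(live._device)    # follows the model
```

Note that a property without a setter cannot be assigned, so with this change the constructor must no longer set self._device itself.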

I have a working branch with this change. Before I open a PR, I'd like to know whether you think this is useful and whether there are other implications I haven't run into so far (the test suite passes, of course). Performance, at least, should not be a concern:

```
>>> %timeit next(lap.model.parameters()).device
1.87 µs ± 2.73 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```
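Outside IPython, the same measurement can be sketched with the standard-library timeit module. The small nn.Linear model is a hypothetical stand-in for lap.model, and absolute timings will vary by machine.

```python
import timeit

from torch import nn

model = nn.Linear(10, 1)  # stand-in for lap.model

n = 100_000
# Time the lookup that the proposed _device property performs on each access.
total = timeit.timeit(lambda: next(model.parameters()).device, number=n)
per_call_us = total / n * 1e6
print(f"{per_call_us:.2f} µs per call")
```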

Here are the combinations of hessian_structure and subset_of_weights that are affected (load_ok=no) or not (load_ok=yes).

| hessian_structure | subset_of_weights | error | load_ok |
|---|---|---|---|
| diag | last_layer | ValueError: Jacobians need to be on the same device as Laplace. | no |
| kron | last_layer | | yes |
| full | last_layer | | yes |
| diag | all | ValueError: Jacobians need to be on the same device as Laplace. | no |
| lowrank | all | RuntimeError: No HIP GPUs are available | no |
| kron | all | | yes |
| full | all | RuntimeError: No HIP GPUs are available | no |
| full | subnetwork | RuntimeError: No HIP GPUs are available | no |

With the proposed fix, all of these work. I can provide more details and code to reproduce the above if needed.

The use case is serializing a fitted Laplace instance, including self.model, to disk while avoiding the usual torch load_state_dict gymnastics (I'm aware of #148, but for small enough models and many test calculations, serializing everything to one blob and loading it back in is more convenient). Later, I may want to load the fitted Laplace from disk and run predictions on other machines in a cluster, where CPU execution is fast enough. Disclaimer: I'm aware that serializing complex objects is not best practice because it has lots of potential pitfalls, and that it is not recommended for long-term data storage.

Still, I think it would be nice if Laplace instances behaved as transparently under torch.load() as torch.nn.Module instances do.

Thanks.

elcorto, Mar 23 '24 09:03

Thanks, Steve. It seems like a neat addition to #148 and I don't see any downside to this. Feel free to open a pull request!

Anyway, I'm just curious about the use cases for this. If I understand correctly, your proposal will mainly save some effort in matching the Laplace specification (i.e., the correct hessian_structure, subset_of_weights, etc.) between different scripts, e.g., if one has two different long-running processes. With this serialization approach, one only needs to change the spec in one process's script.

Are there any other use cases or advantages?

wiseodd, Mar 23 '24 14:03

Thanks for the feedback. I'll prepare a PR then.

The proposed fix is orthogonal to any chosen workflow, but since you asked, I'm happy to elaborate on the reasoning behind using torch.load() + torch.save().

Plain torch offers two ways to save and load trained models, each with its pros and cons: the state_dict way and the pickle way (torch.save() + torch.load() on whole objects). #148 implements the state_dict way for Laplace.

Both ways exist for good reasons. In my view, the state_dict way is best for long-term storage and/or sharing of trained models. This comes at the cost of having to store metadata manually, such as the constructor arguments for the model and for Laplace. Lightning checkpoints automate this away very nicely for Modules / models. There is also ONNX, but I haven't used it in production so far. The pickle way is, as I see it, a quick and convenient way to dump a trained model (and indeed a Laplace instance containing one) to disk when iterating on code, running experiments or debugging, keeping in mind that it is more fragile (e.g., things won't load if you refactor code and rename modules between save and load).
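For a plain nn.Module, the two ways can be sketched as follows. The MLP class and its hidden argument are hypothetical, and an in-memory buffer stands in for a file. The state_dict way needs the constructor arguments stored and supplied separately; the pickle way restores the whole object but needs the class importable under the same name at load time.

```python
import io

import torch
from torch import nn


class MLP(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x):
        return self.net(x)


model = MLP(hidden=16)

# state_dict way: store the tensors plus the metadata needed to rebuild the model.
buf = io.BytesIO()
torch.save({"hidden": 16, "state_dict": model.state_dict()}, buf)
buf.seek(0)
ckpt = torch.load(buf, weights_only=True)  # only tensors and plain containers
restored = MLP(hidden=ckpt["hidden"])      # constructor args supplied manually
restored.load_state_dict(ckpt["state_dict"])

# pickle way: store the whole object; MLP must be importable when loading.
buf = io.BytesIO()
torch.save(model, buf)
buf.seek(0)
restored_pickle = torch.load(buf, weights_only=False)

x = torch.linspace(-1.0, 1.0, 5).unsqueeze(1)
print(torch.equal(model(x), restored(x)), torch.equal(model(x), restored_pickle(x)))
```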

The proposed fix improves torch compatibility of the pickle way for the specific case of using map_location: train on a machine with one device (e.g., a GPU) and load on another machine without such a device:

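```python
# machine with GPU
torch.save(lap, "lap_gpu.pt")

# machine without GPU
# (weights_only=False is needed on torch >= 2.6 to unpickle arbitrary objects)
lap = torch.load("lap_gpu.pt", map_location="cpu", weights_only=False)
```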

This works for torch Tensors and Modules, but currently not for Laplace instances. The fix makes it work by having lap._device respect map_location.

In the other cases, saving and loading on the same machine or on machines with matching hardware (e.g., both have a GPU, not necessarily the same model), things already work:

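```python
# machine with GPU
torch.save(lap, "lap_gpu.pt")

# other machine with GPU
# (weights_only=False is needed on torch >= 2.6 to unpickle arbitrary objects)
lap = torch.load("lap_gpu.pt", weights_only=False)
```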

I hope this answers your question.

elcorto, Mar 24 '24 13:03