
Library breaks double precision

Open · joemrt opened this issue 5 months ago · 2 comments

I noticed that `FullLaplace` produces float32 results even when every object involved (data and model parameters) is double precision (float64). Here is a minimal example:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import laplace

# Data and model parameters are all double precision
dtype = torch.float64
X = torch.randn((100, 3), dtype=dtype)
Y = torch.randn((100, 3), dtype=dtype)
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=10)
model = nn.Linear(3, 3, dtype=dtype)

# Laplace approximation over all weights with a dense (full) Hessian
full_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='full')
full_la.fit(dataloader)
print(full_la.H.dtype)  # prints torch.float32 (at least on my machine)
```
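One possible explanation, which is an assumption and has not been checked against the laplace source: the full Hessian buffer may be allocated with `torch.zeros` without an explicit `dtype`, so it gets PyTorch's default float32, and the float64 per-batch contributions are then added in place. PyTorch's in-place addition keeps the destination's dtype and casts float64 down to float32 without a warning. The sketch below shows only that PyTorch behaviour; `accumulate` is a hypothetical stand-in, not part of the laplace API.

```python
import torch

def accumulate(batches, dtype=None):
    # Hypothetical stand-in for a Hessian accumulator (NOT the laplace API).
    # With dtype=None, torch.zeros uses the global default dtype (float32 unless changed).
    n = batches[0].shape[0]
    H = torch.zeros(n, n, dtype=dtype)
    for G in batches:
        # In-place add keeps H's dtype: if H is float32, the float64
        # product G @ G.T is silently cast down to float32.
        H += G @ G.T
    return H

batches = [torch.randn(4, 2, dtype=torch.float64) for _ in range(3)]
H_default = accumulate(batches)                       # buffer created without a dtype
H_double = accumulate(batches, dtype=batches[0].dtype)  # buffer matches the inputs

print(H_default.dtype)  # torch.float32 under the default settings
print(H_double.dtype)   # torch.float64
```

Allocating the buffer with the inputs' dtype (the second call) keeps the whole computation in double precision, which is the behaviour one would expect from `FullLaplace` here.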

Interestingly, when using KFAC instead of the full Hessian (i.e. with `hessian_structure='kron'` instead of `'full'` above), the Hessian `la.H.to_matrix()` has dtype `torch.float64`.
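A possible workaround, assuming the float32 tensors come from allocations that rely on PyTorch's default dtype (not verified against laplace): temporarily switch the default with `torch.set_default_dtype` while constructing and fitting the Laplace object. The `default_dtype` context manager below is a hypothetical helper, not part of laplace; it only demonstrates that dtype-less allocations pick up the temporary default and that the previous default is restored afterwards.

```python
import contextlib
import torch

@contextlib.contextmanager
def default_dtype(dtype):
    # Hypothetical helper: temporarily change PyTorch's global default dtype
    # and restore the previous one afterwards, even if an exception is raised.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

with default_dtype(torch.float64):
    inside = torch.zeros(3, 3)  # no explicit dtype: uses the temporary default
outside = torch.zeros(3, 3)     # default restored after the with-block

print(inside.dtype)   # torch.float64
print(outside.dtype)  # torch.float32 under the default settings
```

In the example above, the `laplace.Laplace(...)` construction and the `full_la.fit(dataloader)` call would go inside the `with` block. Because this changes global state, restoring the previous default (as the `finally` clause does) avoids affecting unrelated code.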

joemrt · Sep 18 '24 15:09