Laplace
Library breaks double precision
I noticed that FullLaplace produces float32 results even when only double-precision objects are used. Here is a quick example:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import laplace

dtype = torch.float64

# Data and model are all created in double precision
X = torch.randn((100, 3), dtype=dtype)
Y = torch.randn((100, 3), dtype=dtype)
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=10)
model = nn.Linear(3, 3, dtype=dtype)

full_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='full')
full_la.fit(dataloader)
print(full_la.H.dtype)  # prints torch.float32 (at least on my machine)
Interestingly, when using KFAC instead of the full Hessian (i.e., when full above is replaced by kron), the Hessian la.H.to_matrix() is of dtype float64.
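For context, a common source of this kind of dtype mismatch (this is an assumption about the cause, not confirmed from the library source) is that PyTorch tensors allocated without an explicit dtype default to float32. A minimal sketch, including a possible workaround via torch.set_default_dtype, which makes freshly allocated tensors default to float64:

```python
import torch

# Assumption: a tensor allocated without an explicit dtype (e.g. an
# internal Hessian buffer) defaults to float32, regardless of the
# precision of the model and data.
H = torch.zeros(9, 9)
print(H.dtype)  # torch.float32

# Possible workaround: change the global default dtype before fitting,
# so tensors allocated without an explicit dtype come out as float64.
torch.set_default_dtype(torch.float64)
H64 = torch.zeros(9, 9)
print(H64.dtype)  # torch.float64

# Restore the default to avoid side effects elsewhere.
torch.set_default_dtype(torch.float32)
```

This is only a sketch of the suspected mechanism; whether the library's full-Hessian backend actually allocates its buffer this way would need to be checked in the source.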
Thanks @joemrt! I can confirm via this test https://github.com/aleximmer/Laplace/commit/c182504070e3ae47e3ba374a009846f9b5c12f46 that this is indeed unintended behavior.
Great, thanks @wiseodd! At first glance, it looks like the backend is not passed to laplace in the test.