Library breaks double precision
I noticed that `FullLaplace` produces float32 results even when every object involved is double precision. Here is a quick example:
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import laplace

dtype = torch.float64
X = torch.randn((100, 3), dtype=dtype)
Y = torch.randn((100, 3), dtype=dtype)
data = TensorDataset(X, Y)
dataloader = DataLoader(data, batch_size=10)
model = nn.Linear(3, 3, dtype=dtype)

full_la = laplace.Laplace(model=model, subset_of_weights='all',
                          likelihood='regression', hessian_structure='full')
full_la.fit(dataloader)
print(full_la.H.dtype)  # prints torch.float32 (at least on my machine)
```
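A possible workaround, sketched under an assumption I have not verified against the library internals: if the float32 Hessian comes from an internal buffer allocated with PyTorch's global default dtype, switching the default to float64 before fitting should keep `H` in double precision.

```python
# Hedged workaround sketch: assume the float32 Hessian comes from an
# internal allocation that uses torch's global default dtype.
torch.set_default_dtype(torch.float64)

full_la64 = laplace.Laplace(model=model, subset_of_weights='all',
                            likelihood='regression', hessian_structure='full')
full_la64.fit(dataloader)
print(full_la64.H.dtype)  # expected: torch.float64, if the assumption holds
```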
Interestingly, when using KFAC instead of the full Hessian (i.e., when `full` above is replaced by `kron`), the Hessian `la.H.to_matrix()` is of dtype `torch.float64`.
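For reference, a minimal sketch of that K-FAC check, reusing the data and model from the example above (only `hessian_structure` changes):

```python
# K-FAC variant of the example above.
la = laplace.Laplace(model=model, subset_of_weights='all',
                     likelihood='regression', hessian_structure='kron')
la.fit(dataloader)

# Expanding the Kronecker-factored Hessian to a dense matrix; unlike the
# 'full' case, this comes back in double precision on my machine.
print(la.H.to_matrix().dtype)  # torch.float64
```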