
Regression covariance is only diagonal, with the same value across it

Open ArturPrzybysz opened this issue 2 years ago • 6 comments

I have created a last-layer BNN with your package and used the "kron" and "diag" Hessian structures in a regression task. However, just as stated in the title, the diagonal of the covariance matrix consists of a single repeated value.

Is this expected behavior?

I could provide a minimal code example if this is not expected and you suspect an error in the implementation.

ArturPrzybysz avatar May 21 '22 14:05 ArturPrzybysz

Hi Artur, thanks for pointing this issue out. It would be great if you could provide a minimal example to reproduce this so we can look into it. Mathematically, this should not happen, or at least only in very special cases.

aleximmer avatar May 25 '22 09:05 aleximmer

@AlexImmer Thank you for the response! The example I provide is not minimal, but it illustrates the situation well enough. Training the MAP model from scratch is optional; I also provide a pretrained model state_dict to fit the Laplace approximation to.

The code is here: https://github.com/ArturPrzybysz/LaplaceConvDemo

The starting file is src/main.py, where the predictions are visualized; the diagonal of the covariance matrix is also printed and asserted to consist of a single value.

ArturPrzybysz avatar May 26 '22 17:05 ArturPrzybysz

Hi Artur, it seems that there is indeed an issue with the glm predictive. When I change this line in your main.py:

pred = la_model.la(X)

into

pred = la_model.la(X, pred_type='nn', n_samples=10)

the variances no longer all have the same value. We'll investigate this further, but for now you can use the nn predictive, which should work just fine (make sure to increase n_samples, though).
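For concreteness, a minimal sketch of that workaround, continuing from a fitted regression Laplace object la and a batch of inputs X (for regression, the nn predictive should return a Monte Carlo mean and per-output variance, but the exact return structure is worth double-checking for your version):

# Sampling-based ('nn') predictive instead of the closed-form GLM predictive.
# More samples reduce the Monte Carlo noise in the estimated variances.
pred_mean, pred_var = la(X, pred_type='nn', n_samples=100)
print(pred_var)  # per-output variances; these are not forced to share one value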

wiseodd avatar May 27 '22 16:05 wiseodd

Sweet! Thank you for the help.

ArturPrzybysz avatar May 27 '22 18:05 ArturPrzybysz

@ArturPrzybysz One other thing I noticed when briefly looking at your code: the tuning of the prior precision is probably not working as you intended. You pass the validation loader but keep the default method='marglik' argument, so the validation loader will not be used at all. If you want to use it, you will have to set method='CV' and also change the loss argument, which defaults to the cross-entropy loss, whereas you probably want something like the MSE loss for regression. We will add additional checks to avoid this unintended behavior. Moreover, you should also specify the pred_type and link_approx arguments to make them consistent with the settings you are using for prediction later on.
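For illustration, the call would then look something like the sketch below. The method, loss, pred_type, and link_approx arguments are the ones named above; the val_loader keyword and the exact loss signature are assumptions on my part, so treat this as a sketch rather than the definitive call:

# Hedged sketch: tune the prior precision on held-out data instead of the marginal likelihood.
# Without method='CV', the default method='marglik' ignores the validation loader entirely.
la.optimize_prior_precision(
    method='CV',                        # actually use the validation data
    val_loader=val_loader,              # assumed keyword for the validation loader
    loss=torch.nn.functional.mse_loss,  # regression loss instead of the cross-entropy default
    pred_type='glm',                    # keep consistent with the prediction settings used later
    link_approx='probit',               # likewise; only relevant for classification
)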

runame avatar May 27 '22 20:05 runame

@runame Wow, thank you!

ArturPrzybysz avatar May 28 '22 08:05 ArturPrzybysz

Revisiting this issue using the quick script below.

This problem happens in last-layer Laplace (all-layer is fine), for any Hessian structure, with the following backends:

  • AsdlGGN
  • BackPackGGN
  • CurvlinopsEF
  • CurvlinopsGGN
  • CurvlinopsHessian

from laplace import Laplace, ParametricLaplace
from laplace.curvature import (
    CurvlinopsGGN,
    CurvlinopsEF,
    CurvlinopsHessian,
    AsdlGGN,
    AsdlEF,
    AsdlHessian,
    BackPackGGN,
    BackPackEF,
)
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Toy 5-output regression model with random training and test data
model = nn.Sequential(nn.Linear(3, 10), nn.ReLU(), nn.Linear(10, 5))
trainloader = DataLoader(
    TensorDataset(torch.randn(16, 3), torch.randn((16, 5))), batch_size=3
)
testloader = DataLoader(
    TensorDataset(torch.randn(7, 3), torch.randn((7, 5))), batch_size=3
)
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

for _ in range(100):
    for x, y in trainloader:
        opt.zero_grad()
        out = model(x)
        loss = F.mse_loss(out, y)
        loss.backward()
        opt.step()

# Last-layer Laplace with a Kronecker-factored GGN; the other backends listed above behave the same
la = Laplace(
    model,
    likelihood="regression",
    subset_of_weights="last_layer",
    hessian_structure="kron",
    backend=CurvlinopsGGN,
)
la.fit(trainloader)
# la.optimize_prior_precision()

# GLM predictive: the diagonal of each predictive covariance contains one repeated value
for x, _ in testloader:
    pred_mean, pred_cov = la(x)
    pred_var = torch.diagonal(pred_cov, dim1=-2, dim2=-1)
    print(pred_var)

wiseodd avatar Apr 29 '24 16:04 wiseodd

This is actually a limitation of Bayesian linear regression with an isotropic Gaussian prior in general.

Let $f(x) = W \phi(x)$ be the model where $f(x) \in \mathbb{R}^c$, $W \in \mathbb{R}^{c \times d}$, and $\phi(x) \in \mathbb{R}^d$. Let $\mathrm{vec}(W) \sim \mathcal{N}(0, \sigma_0^2 I_{cd \times cd})$.

Given a dataset $\mathcal{D} = \{ (x_i, y_i) \}_{i=1}^n$, the (exact) Hessian (equivalently, the GGN or Fisher) is:

$$ H = \sum_{i=1}^n (\phi(x_i) \otimes I_{c \times c}) \, I_{c \times c} \, (\phi(x_i) \otimes I_{c \times c})^\top , $$

where the middle $I_{c \times c}$ is the Hessian of the MSE loss w.r.t. $f(x_i)$.

Notice the Kronecker structure: $H$ is block diagonal, consisting of $c$ identical blocks, each arising from $\sum_i \phi(x_i) \phi(x_i)^\top$. This implies that the diagonal of $H$ also has a repeating block structure, and hence that the posterior covariance $\Sigma = (H + \sigma_0^{-2} I)^{-1}$ has the same structure.

Now, when making a prediction, we compute the predictive covariance $(\phi(x_*) \otimes I_{c \times c})^\top \, \Sigma \, (\phi(x_*) \otimes I_{c \times c})$. All of these matrices have the block structure above, which implies that the $c \times c$ predictive covariance is diagonal with the same value in every coordinate.
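To spell out that last step (using the notation above; this is just the standard Kronecker-product identities written out), the posterior precision factorizes, with the ordering used in the Kronecker products above, as

$$ H + \sigma_0^{-2} I_{cd \times cd} = A \otimes I_{c \times c}, \qquad A := \sum_{i=1}^n \phi(x_i) \phi(x_i)^\top + \sigma_0^{-2} I_{d \times d}, $$

so the predictive covariance collapses to a scaled identity:

$$ (\phi(x_*) \otimes I_{c \times c})^\top \, (A^{-1} \otimes I_{c \times c}) \, (\phi(x_*) \otimes I_{c \times c}) = \big( \phi(x_*)^\top A^{-1} \phi(x_*) \big) \, I_{c \times c} . $$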

Note that this happens universally for the GGN/Hessian, regardless of the factorization (diag, kron, full). It is also clear why the EF does not have this issue (although that implies using the "wrong" Hessian).
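To see this concretely, here is a small self-contained check of the algebra above in plain PyTorch, without the laplace package; Phi stands in for the last-layer features $\phi(x_i)$ and all numbers are random, purely for illustration:

import torch

torch.manual_seed(0)
c, d, n = 5, 10, 16                      # outputs, features, data points
Phi = torch.randn(n, d)                  # last-layer features phi(x_i)
sigma0 = 1.0                             # isotropic prior scale

# GGN of the last layer for MSE: sum_i (phi_i x I_c) I_c (phi_i x I_c)^T
I_c = torch.eye(c)
H = sum(torch.kron(Phi[i:i+1].T @ Phi[i:i+1], I_c) for i in range(n))

# Posterior covariance under the isotropic prior N(0, sigma0^2 I)
Sigma = torch.linalg.inv(H + sigma0 ** -2 * torch.eye(c * d))

# Predictive covariance at a test point: J Sigma J^T with J = (phi_* x I_c)^T
phi_star = torch.randn(d, 1)
J = torch.kron(phi_star, I_c).T          # shape (c, c*d)
pred_cov = J @ Sigma @ J.T
print(pred_cov)                          # ~ scalar * identity: equal diagonal, zero off-diagonal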

The fix

This limitation comes from the linear model itself under an isotropic prior, so you can simply move away from that prior. In Laplace, this is easy:

la = Laplace(model, "regression", subset_of_weights="last_layer", ...)
la.optimize_prior_precision(prior_structure="diag")
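(With a per-parameter prior precision, the posterior precision $H + \mathrm{diag}(\lambda)$ no longer consists of identical blocks, so the predictive covariance is not forced to be a scaled identity and the per-output variances can differ.)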

wiseodd avatar Apr 29 '24 18:04 wiseodd

Sorry @ArturPrzybysz for taking 2 years to answer this issue!

wiseodd avatar Apr 29 '24 18:04 wiseodd

@wiseodd no problem, this package helped me a ton with my MSc thesis anyway!

ArturPrzybysz avatar Apr 29 '24 19:04 ArturPrzybysz