
Missing implementation of supported layers for DiagHessian and BatchDiagHessian

hlzl opened this issue 1 year ago · 0 comments

Multiple layers are specified as supported for second-order extensions but actually do not work when computing the Hessian diagonal with backpack-for-pytorch<=1.6.0.

So far, I've run into this problem with the following layers:

  • [ ] backpack.custom_module.branching.ScaleModule, torch.nn.Identity
  • [ ] torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
  • [ ] backpack.custom_module.branching.SumModule

This can be tested with a script such as the following:

import torch
from backpack import backpack, extend
from backpack.extensions import DiagHessian, BatchDiagHessian
from backpack.custom_module.branching import Parallel, SumModule

model = extend(
    torch.nn.Sequential(
        *[
            torch.nn.Conv2d(3, 16, kernel_size=(3, 3)),
            Parallel(
                torch.nn.Identity(), torch.nn.BatchNorm2d(16), merge_module=SumModule()
            ),
            torch.nn.AdaptiveAvgPool2d(output_size=1),
            torch.nn.Flatten(),
            torch.nn.Linear(16, 2),
        ]
    ).cuda()
)
criterion = extend(torch.nn.CrossEntropyLoss())

batch = torch.randn((2, 3, 8, 8)).cuda()
target = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).cuda()

model.eval()
model.zero_grad()
loss = criterion(model(batch), target)

with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

hessian_diag = torch.cat(
    [p.diag_h.view(-1) for p in model.parameters()], dim=0
)
hessian_diag_batch = torch.cat(
    [p.diag_h_batch.view(batch.shape[0], -1) for p in model.parameters()], dim=1
)

I'm guessing these require independent fixes, but I think it's a good idea to collect all layers with missing support here.
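One reason the fixes are likely independent: as I understand it, each BackPACK extension resolves second-order quantities through its own mapping from module classes to derivative implementations, so a layer can be covered by one extension's mapping while missing from another's. The sketch below is illustrative only (class and attribute names are invented, not BackPACK's actual internals):

```python
# Illustrative sketch (not BackPACK's real code): an extension keeps a
# per-module-type registry, and a missing entry is exactly what produces
# an "unsupported layer" failure for that extension only.
class DiagHessianLike:
    def __init__(self):
        # Only some layer types have registered implementations.
        self.module_exts = {"Linear": "diag_h_linear", "Conv2d": "diag_h_conv2d"}

    def lookup(self, module_name):
        impl = self.module_exts.get(module_name)
        if impl is None:
            raise NotImplementedError(f"No DiagHessian support for {module_name}")
        return impl
```

Under this picture, adding support for e.g. Identity to DiagHessian would not automatically fix BatchDiagHessian, since each extension's registry must be extended separately.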

hlzl avatar Sep 05 '23 13:09 hlzl