
fix: ScaleModule and SumModule for DiagHessian.

Open hlzl opened this issue 1 year ago • 3 comments

Partially fixes #316. ScaleModule is also used for torch.nn.Identity.

I am not sure whether hessian_is_zero() should always return True for those two modules. The same applies to accumulate_backpropagated_quantities(), which concatenates dicts instead of tensors as is done for the DiagGGN.

hlzl avatar Sep 05 '23 14:09 hlzl

Commit 74c41735a12b72861e17e5ed4c2b0a97d40283c0 makes it possible to compute the Hessian diagonal even if the network contains a batch norm layer, by simply skipping the Hessian elements of the batch norm layer.

I am not sure whether this is a reasonable approach, but it can serve as a quick fix.

The remaining diagonal elements can then be extracted as follows:

import torch

# Concatenate the per-sample Hessian diagonals of all parameters that
# have one; batch norm parameters are skipped because no `diag_h_batch`
# attribute was set for them.
hessian_diag_wo_bn = torch.cat(
    [
        p.diag_h_batch.view(batch.shape[0], -1)
        for p in model.parameters()
        if "diag_h_batch" in p.__dict__.keys()
    ],
    dim=1,
)
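To illustrate the concatenation pattern in isolation, here is a minimal, self-contained sketch that mimics the situation above without BackPACK: we attach a fake diag_h_batch attribute (normally populated by the extension during the backward pass) to the linear layer's parameters only, so the batch norm parameters are excluded exactly as in the quick fix. The zero values are placeholders, not real Hessian entries.

```python
import torch

# Toy model: one linear layer followed by batch norm.
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.BatchNorm1d(3))
batch_size = 8

# Fake per-sample Hessian diagonals for the linear layer only
# (stand-in for what BackPACK would compute; batch norm gets none).
for p in model[0].parameters():
    p.diag_h_batch = torch.zeros(batch_size, *p.shape)

# Same extraction as in the comment above.
hessian_diag_wo_bn = torch.cat(
    [
        p.diag_h_batch.view(batch_size, -1)
        for p in model.parameters()
        if "diag_h_batch" in p.__dict__.keys()
    ],
    dim=1,
)

# Resulting shape: (8, 15) -- 12 weight + 3 bias entries of the linear
# layer, with the batch norm parameters excluded.
print(hessian_diag_wo_bn.shape)
```

The filter on p.__dict__ works because BackPACK stores its quantities as plain Python attributes on the parameter tensors, so parameters of skipped layers simply never receive the attribute.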

hlzl avatar Sep 05 '23 15:09 hlzl

Commit 48f03e92e12fd97970a78335fb645ac5ee9a77f2 tries to actually compute the diagonal elements of the Hessian for the batch norm layer.

If one of you could have a quick look at the commits to see if they make sense, I would really appreciate it. @f-dangel @fKunstner

Thank you!

hlzl avatar Sep 05 '23 17:09 hlzl

Hi,

just wanted to let you know that I read your message above. Please don't expect any reaction before the ICLR deadline (Sep 28).

f-dangel avatar Sep 14 '23 14:09 f-dangel