fix: ScaleModule and SumModule for DiagHessian.
Partially fixes #316. ScaleModule is also used for torch.nn.Identity.
Not sure whether hessian_is_zero() should always return True for those two modules. The same question applies to accumulate_backpropagated_quantities(), which concatenates dicts here, instead of tensors as for the DiagGGN.
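If I understand correctly, both modules are linear in their inputs (ScaleModule multiplies by a fixed factor, SumModule adds its inputs), so the Hessian of the module output with respect to the input should vanish, which is why I would lean towards returning True. A quick sanity check with plain PyTorch autograd (illustrative only, not BackPACK internals):

```python
import torch

# ScaleModule computes y = weight * x and SumModule computes y = x1 + x2;
# both are linear in their inputs, so the Hessian of the output w.r.t. the
# input is identically zero. Quick check for the scaling case:
weight = 0.7
x = torch.randn(5, requires_grad=True)

def scaled_sum(x):
    # any linear reduction of the (linear) module output
    return (weight * x).sum()

hess = torch.autograd.functional.hessian(scaled_sum, x)
print(torch.allclose(hess, torch.zeros_like(hess)))  # prints: True
```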
Commit 74c41735a12b72861e17e5ed4c2b0a97d40283c0 makes it possible to compute the Hessian diagonal even if there is a batch norm layer in the network, simply by not computing the Hessian elements for the batch norm layer.
I am not sure whether this is a reasonable approach, but it can serve as a quick fix.
The other diagonal elements can then be extracted as follows:
```python
hessian_diag_wo_bn = torch.cat(
    [
        p.diag_h_batch.view(batch.shape[0], -1)
        for p in model.parameters()
        if "diag_h_batch" in p.__dict__.keys()  # skip parameters without a Hessian diagonal (batch norm)
    ],
    dim=1,
)
```
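For completeness, a sketch of the surrounding usage (this assumes the quick-fix commit above, so that the batch norm layer is simply skipped, and BackPACK's BatchDiagHessian extension, which stores diag_h_batch on each supported parameter; the model and dimensions are just placeholders):

```python
import torch
from backpack import backpack, extend
from backpack.extensions import BatchDiagHessian

# toy model containing a batch norm layer
model = extend(torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.BatchNorm1d(10),
    torch.nn.Linear(10, 2),
))
lossfunc = extend(torch.nn.CrossEntropyLoss())

batch = torch.randn(8, 10)
labels = torch.randint(0, 2, (8,))

loss = lossfunc(model(batch), labels)
with backpack(BatchDiagHessian()):
    loss.backward()

# p.diag_h_batch is now set for every supported parameter; the snippet above
# then concatenates these diagonals while skipping the batch norm parameters.
```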
Commit 48f03e92e12fd97970a78335fb645ac5ee9a77f2 tries to actually compute the diagonal elements of the Hessian for the batch norm layer.
If one of you could have a quick look at the commits to see whether they make sense, I would really appreciate it. @f-dangel @fKunstner
Thank you!
Hi,
just wanted to let you know that I read your message above. Please don't expect any reaction before the ICLR deadline (Sep 28).