KFAC support in BatchNorm (eval mode)
Hi,
Thanks for the repo! This is really nice work. I am planning to compute KFAC with BackPACK, but it raises the following error:
NotImplementedError: Extension saving to kfac does not have an extension for Module <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
My network is as follows:
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, 3, stride=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(36, 10),
)
loss_func = nn.CrossEntropyLoss()
When calculating the KFAC with:
from backpack import backpack, extend
from backpack.extensions import KFAC

model_ = extend(model.eval())
logits = model_(X)
loss = extend(loss_func)(logits, Y)
with backpack(KFAC(mc_samples=1000)):
    loss.backward()
It raises the NotImplementedError above. I am wondering whether computing KFAC in a network with BN layers in the middle is supported by BackPACK? It seems like it should be, since it successfully works for ResNet.
Thanks
Hi, thanks for your question!
Just to make sure I'm getting it right: Do you want to compute KFAC for the Conv2d and Linear layers in your network, or do you want to compute KFAC for the parameters of the BatchNorm2d layer? (For the latter, I'm not sure KFAC is defined.)
Indeed, I want to compute the KFAC for the parameters of the Conv2d, Linear, and BatchNorm2d layers. Is it possible to achieve this?
BackPACK can compute KFAC for Linear and Conv2d layers, but not for BatchNorm2d. I don't know how the KFAC papers deal with batch normalization. Do you know? If so, one could implement this missing feature.
Sadly, there is no easy way to tell BackPACK to ignore the parameters of batch norm layers, because it tries to compute its quantities on all parameters that have requires_grad=True.
If you want to get KFAC for the supported layers, you will have to set p.requires_grad=False for the BN parameters. But then you also won't get their gradient.
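A minimal sketch of what freezing the BN parameters could look like (for the model above):

from torch import nn

# Freeze the BatchNorm parameters so BackPACK does not try to compute KFAC (or .grad) for them.
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        for p in module.parameters():
            p.requires_grad = False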
Thanks for the prompt reply! Yep, I have not seen any paper discussing the KFAC calculation for BN either...
To get the KFAC for the supported layers (Linear and Conv2d), I found a way to bypass the NotImplementedError.
# Extend hbp/__init__.py by
class HBP(SecondOrderBackpropExtension):
def __init__(
self,
curv_type,
loss_hessian_strategy,
backprop_strategy,
ea_strategy,
savefield="hbp",
):
...
super().__init__(
savefield=savefield,
fail_mode="ERROR",
module_exts={
...
Identity: custom_module.HBPScaleModule(),
BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
},
)
# The HBPBatchNormNd is defined as
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule
class HBPBatchNormNd(HBPBaseModule):
def __init__(self):
super().__init__(BatchNormNdDerivatives(), params=None)
With this modification, it works without raising the error. I am not quite sure whether this is the right approach. Do you have any advice?
Hi,
that workaround looks good! Indeed, this will ignore the BN parameters, while keeping BackPACK's backpropagation through the layer for KFAC intact.
That's a great relief! :)
One way to get started on this would be to add support for KFAC in BatchNorm in evaluation mode.
I will outline in the following what needs to be done (this may not be 100% technically accurate).
Pull requests welcome.
Let's assume a BatchNorm1d layer that takes an input X of shape [N, C] and maps it to an output Z of shape [N, C]. The parameters γ and β are both of shape [C]. The forward pass (in evaluation mode) is

Z[n, :] = γ ⊙ X[n, :] + β,    n = 1, ..., N

(where ⊙ is elementwise multiplication). This looks a bit like a Linear layer with weights W = diag(γ) and bias b = β.
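As a quick illustration of that correspondence (a toy check, not part of BackPACK):

import torch
from torch import nn

N, C = 4, 3
X = torch.randn(N, C)
gamma, beta = torch.randn(C), torch.randn(C)

# A Linear layer with W = diag(gamma) and b = beta reproduces Z[n, :] = gamma ⊙ X[n, :] + beta.
linear = nn.Linear(C, C)
linear.weight.data = torch.diag(gamma)
linear.bias.data = beta

print(torch.allclose(gamma * X + beta, linear(X)))  # True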
We don't really need a Kronecker factorization here, because the curvature blocks for γ and β are both of shape [C, C]. So instead we compute the MC-sampled Fisher/GGN block:
- Computing the KFAC for β is like computing the MC-approximated Fisher block for β:

  backpropagated_grads = ...  # from BackPACK, has shape [M, N, C] where M denotes the number of MC samples
  v = backpropagated_grads
  JTv = v  # apply transpose Jacobian (identity in this case)
  # square the result to get the GGN block
  kfac_beta = einsum("mnc,mnd->cd", JTv, JTv)  # [C, C]
  return [kfac_beta]  # The KFAC extension returns lists with Kronecker factors
- Computing the KFAC for γ is like computing the MC-approximated Fisher block for γ:

  backpropagated_grads = ...  # from BackPACK, has shape [M, N, C] where M denotes the number of MC samples
  X = module.input0
  v = backpropagated_grads
  JTv = einsum("mnc,nc->mnc", v, X)  # apply transpose Jacobian
  # square the result to get the GGN block
  kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv)  # [C, C]
  return [kfac_gamma]  # The KFAC extension returns lists with Kronecker factors
- One can test this by setting N=1 and checking that kfac_gamma → GGN(gamma) and kfac_beta → GGN(beta) as the number of MC samples grows (M → ∞).
- To generalize this to BatchNormNd, simply replace "mnc" by "mnc...", "mnd" by "mnd...", and "nc" by "nc..." in the above einsums.
- There should be error messages if the module is not in evaluation mode.
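Regarding the last point, a minimal sketch of such a mode check could look like this (an illustrative helper, not BackPACK's actual code):

def check_batch_norm_mode(module):
    # The derivation above only covers evaluation mode, so fail loudly otherwise.
    if module.training:
        raise NotImplementedError(
            "KFAC for BatchNorm is only supported in evaluation mode; "
            "call module.eval() before the backward pass."
        )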
Thanks for your guidance! :)
A pull request has been opened. A simple test case is also included to verify the result and the mode checking.
One more question: why is it not necessary to divide kfac_gamma by JTv.shape[0], which is the number of MC samples, when computing it as
kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv)  # [C, C]
Hi,
thanks for the PR; apologies, you will have to be patient with my review.
Regarding your question: Good point! The factor 1 / sqrt(M), where M is the number of MC samples, is inserted by the loss function, which creates the MC-approximated Hessian square root that is then backpropagated through all layers. Squaring that results in the desired 1 / M.
For CrossEntropyLoss this happens here in the code (M denotes the number of MC samples).
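To make the "squaring" explicit in the einsum notation from above: scaling the backpropagated samples by 1 / sqrt(M) before the einsum is the same as dividing the result by M, e.g. (quick check, not BackPACK code):

import torch

M, N, C = 1000, 1, 3
v = torch.randn(M, N, C)
lhs = torch.einsum("mnc,mnd->cd", v / M**0.5, v / M**0.5)
rhs = torch.einsum("mnc,mnd->cd", v, v) / M
print(torch.allclose(lhs, rhs))  # True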
Best, Felix
Is there a way to make BackPACK ignore the modules it does not support? I want to use it with models I did not implement myself (timm models, for example).
Hi,
are you asking this question w.r.t. KFAC?
If you want to use a first-order extension, you can simply extend the layers that are supported by BackPACK.
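For the first-order case, something like this could work (an untested sketch; the model, data, and BatchGrad extension here are just an example):

import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(648, 10),
)
loss_func = extend(nn.CrossEntropyLoss())

# Extend only the layers BackPACK supports; the BatchNorm layer stays untouched.
for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        extend(module)

X, y = torch.randn(4, 1, 28, 28), torch.randint(0, 10, (4,))
with backpack(BatchGrad()):
    loss_func(model(X), y).backward()
# The Conv2d/Linear parameters now carry .grad_batch; the BatchNorm parameters only get .grad.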
If you want to use a second-order extension, all layers must be supported by BackPACK, as otherwise it cannot backpropagate the additional information through the compute graph.
Best, Felix
I mean a second-order extension. I just want an estimate and I don't care if the value is not precise; I just want to check how it changes during training. Can I simply remove the module if it does not change the dimensions?
Hey,
not sure if I'm following exactly what you want to do. If you remove the BN layers and all remaining layers are supported by BackPACK, you can use second-order extensions. But your network will also behave differently, because you eliminated the BN layers.
Best, Felix
I only use BackPACK's second-order extensions to measure the Fisher information of the weights during training, but I do not use it for the training itself. Every n steps, before the next training step, I use the Hessian to measure the information, save the result, and clear the gradients for the next step. It is not a problem for me to skip the information in the batchnorm layers.
If I got your setup right, that will still be difficult without implementing the batchnorm operation. There is no option to disable the extension on the batchnorm parameters only, because backpack still needs to backpropagate the second-order information through the batchnorm layer to compute the information for the parameters of the earlier layers.
Here's a workaround that could work without having to code the batchnorm extension.
Say we start with the network
import torch
from torch.nn import BatchNorm1d, Linear, Sequential

net1 = Sequential(
    Linear(1, 1),
    BatchNorm1d(1),
    Linear(1, 1),
)
We can make a second network that does the same operation as net1 (if net1 is in eval mode) using only Linear layers,
net2 = Sequential(
    Linear(1, 1),
    Linear(1, 1),
    Linear(1, 1),
    Linear(1, 1),
)
To make them the same, we need to map the weights from net1 to net2.
For the linear layers, we just copy the data
net2[0].weight.data = net1[0].weight.data
net2[0].bias.data = net1[0].bias.data
net2[3].weight.data = net1[2].weight.data
net2[3].bias.data = net1[2].bias.data
And we should be able to implement the batchnorm operation with 2 linear layers by remapping them as follows (needs a double check)
# Implement the normalization
# x -> (x - running_mean) / sqrt(running_var + eps) = (1 / sqrt(running_var + eps)) * x - running_mean / sqrt(running_var + eps)
bnlayer = net1[1]
scale = 1 / torch.sqrt(bnlayer._buffers["running_var"].data + bnlayer.eps)
net2[1].weight.data = scale.reshape(1, 1)  # keep the [1, 1] weight shape of Linear(1, 1)
net2[1].bias.data = -bnlayer._buffers["running_mean"].data * scale  # note the minus sign
net2[2].weight.data = bnlayer._parameters["weight"].data.reshape(1, 1)
net2[2].bias.data = bnlayer._parameters["bias"].data
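Since this mapping "needs a double check", a quick numerical sanity check could look like this (continuing the code above): with net1 in eval mode, both networks should produce the same output.

net1.eval()
x = torch.randn(5, 1)
print(torch.allclose(net1(x), net2(x), atol=1e-6))  # expected: True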
Now we can extend net2 using BackPACK to compute KFAC.
Instead of doing

extend(net1)
...
with backpack(KFAC()):
    loss(net1).backward()

we do

extend(net2)
...
map_weights(net1, net2)
with backpack(KFAC()):
    loss(net2).backward()
inverse_map_grad_and_kfac(net2, net1)
where inverse_map_grad_and_kfac would map the _.grad and _.kfac attributes of the parameters of net2 to the right parameters of net1. (Although operations on _.data shouldn't be tracked by autodiff, maybe put all of this in a torch.no_grad() block to make sure gradients don't get propagated from one network to the other?)
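For illustration, a sketch of what the hypothetical inverse_map_grad_and_kfac could look like (only the two outer Linear layers have direct counterparts in net1):

def inverse_map_grad_and_kfac(net2, net1):
    # (net2 layer, net1 layer) pairs whose parameters correspond one-to-one
    pairs = [(net2[0], net1[0]), (net2[3], net1[2])]
    with torch.no_grad():
        for src, tgt in pairs:
            for name in ("weight", "bias"):
                p_src, p_tgt = getattr(src, name), getattr(tgt, name)
                p_tgt.grad = p_src.grad.clone()
                p_tgt.kfac = [factor.clone() for factor in p_src.kfac]  # list of Kronecker factors

The BatchNorm parameters of net1 get nothing, which matches the use case above of not needing the information for the batchnorm layers.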
Thanks a lot for your thoughts! I will study this... :-)