functorch
backward hooks aren't called when using vmap
Hi All,
I've been trying to get the forward activations and backward sensitivities for layers within my model. The `forward_pre_hook` works as expected, but the `full_backward_hook` never gets called.
A minimal reproducible example of this error is below:
```python
import torch
from torch import nn
from functorch import make_functional, vmap

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

net = model()

# forward pre-hook and full backward hook
def _save_input(module, input) -> None:
    print("forward")

def _save_output(module, grad_input, grad_output) -> None:
    print("backward")

for mod in net.modules():
    mod.register_forward_pre_hook(_save_input)
    mod.register_full_backward_hook(_save_output)

x = torch.randn(100, 2)
y = net(x)                     # works
y = vmap(net, in_dims=(0))(x)  # fails
loss = torch.mean(y)
loss.backward()
```
When I'm using `vmap`, the output is:

```
forward
forward
forward
```
When I'm not using `vmap`, the output is:

```
forward
forward
forward
backward
backward
backward
```
For reference:

```
torch.__version__      1.12.0a0+git7c2103a
functorch.__version__  0.2.0a0+9d6ee76
```
I am not sure what is wrong here, but it sounds very similar to `autograd.Function` not working. @albanD, do you have some context or reference on how `nn.Module` hooks work?
The `full_backward_hook` does rely on custom Functions. So if custom Functions are not supported, then yes, that won't work I'm afraid.
Cool, thanks for the analysis, Alban. We are back to fixing the `autograd.Function` problem.
@AlphaBetaGamma96 what are you using backward hooks for? Is it just per-sample-grad computation? If so, we have other workarounds for that.
I'm working on some gradient preconditioning techniques, and they require the forward activations and backward sensitivities (`grad_output`) of all `nn.Module` objects of a network. I am also calculating per-sample gradients and have been using `vmap` with `jacrev`/`grad`, but I think the fact that I'm using some parts of my code with `vmap` and other parts without is causing some issues.
One of the errors does reference `autograd.Function`:

```
RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
```
@zou3519 Do you know if it's possible to compute the `grad_output` of a layer via `vmap`, or is it only possible via hooks?
functorch doesn't see modules, it only sees PyTorch operations (e.g. add, mul, ...), so there isn't a way to do this right now.
By the way, while the `nn.Module`'s `full_backward_hook` will not work, the basic autograd hooks at the `Tensor` level should work.
So in particular for your use case, if you just want the gradient for x, then you can get that by using a `register_forward_pre_hook` where the hook does:

```python
def hook(mod, inputs):
    # forward pre-hooks receive the module's positional inputs as a tuple
    inputs[0].register_hook(your_backward_hook)
```

It is more complex and will be hard to generalize to `nn.Module`s that have multiple inputs (which is why the full backward hook needs a custom Function), but that might solve your problem?
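For concreteness, here is a minimal sketch of that wiring against the repro model above, assuming Tensor-level hooks do fire under `vmap` as described (`print_grad` is a hypothetical stand-in for a real backward hook):

```python
def print_grad(grad):
    # grad has the same shape as the Tensor the hook was registered on
    print("backward", grad.shape)

def _attach_input_hook(module, inputs):
    # Tensor hooks can only be registered on Tensors that require grad,
    # so the raw network input (plain data) is skipped
    if inputs[0].requires_grad:
        inputs[0].register_hook(print_grad)

for mod in (net.fc1, net.fc2):
    mod.register_forward_pre_hook(_attach_input_hook)
```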
Hi @albanD, thanks for the insight on how to use `register_hook`! Just so I understand this correctly, I can register a hook on a Tensor and use my `forward_pre_hook` and `full_backward_hook` formulae, and that'd return (for all samples) the input, and the backward sensitivity (`grad_output`) for that Tensor?
I've managed to get around this issue by calculating `grad_output` via the 'normal' PyTorch way and then just using functorch to calculate other higher-order gradients. I need the input/`grad_output` term from just the output of my model, which is then used in tandem with higher-order gradients of other functions. So I think my current method should be ok, although the implementation isn't the neatest!
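For anyone hitting the same thing, a hedged sketch of that kind of workaround, reusing `net` and `x` from the repro above (names are illustrative): capture `grad_output` with a plain, non-vmap backward pass, and keep functorch for the higher-order pieces.

```python
captured = {}

def save_grad_output(module, inputs, output):
    def _store(grad):
        # gradient of the loss w.r.t. this module's output, shape [B, out_dim]
        captured[module] = grad
    output.register_hook(_store)

handle = net.fc2.register_forward_hook(save_grad_output)
out = net(x)           # plain forward, no vmap
out.mean().backward()  # captured[net.fc2] now holds grad_output
handle.remove()
```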
> I can register a hook on a Tensor and use my forward_pre_hook and full_backward_hook formulae, and that'd return (for all samples) the input, and the backward sensitivity (grad_output) for that Tensor?
Yes. For simple `nn.Module`s, that will allow you to "simulate" the `register_full_backward_hook` that doesn't work.
Hi @albanD, I had a quick look at `register_hook`, but it seems that the signature of that hook only takes the gradient of a Tensor rather than the `grad_output` values that I need. The explicit signature is `hook(grad) -> Tensor or None`, so I can't seem to get `grad_output`. I need `grad_output` as I want the backward sensitivities for all samples (and the same for the activations, but `forward_pre_hook` works with vmap so that's not a problem).
I did notice that `register_backward_hook` seems to work with functorch, although it is deprecated and I'm pretty sure it has incorrect behavior for modules with multiple inputs/outputs? All the layers of my network only have a single input and output Tensor, so do you think it might be ok to use `register_backward_hook` for the meantime? Or should I just wait for custom function support for backward hooks?
> the gradient of a Tensor rather than the grad_output values that I need.
What do you mean by that? The `grad_output` is the gradient of the output. So if you `register_hook` on the output of the Module, you will get the gradient of the output (also called `grad_output` in the module hook's doc).
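A quick sanity check of that equivalence outside of vmap (illustrative and self-contained; names are mine):

```python
import torch
from torch import nn

lin = nn.Linear(2, 3)
inp = torch.randn(5, 2)
grads = {}

# module-level hook: grad_output arrives as a tuple
lin.register_full_backward_hook(
    lambda mod, gi, go: grads.__setitem__("module_hook", go[0]))

out = lin(inp)
# Tensor-level hook registered directly on the module's output
out.register_hook(lambda g: grads.__setitem__("tensor_hook", g))
out.sum().backward()

# both hooks saw the same gradient w.r.t. the output
assert torch.allclose(grads["module_hook"], grads["tensor_hook"])
```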
So, what I need is the gradient of the output of a module (which is what `full_backward_hook` returns, although it doesn't work with functorch atm), but after checking `register_hook`, it seems to only allow for a Tensor of the same shape as the Tensor it's registered on, whereas `full_backward_hook` returns a shape of `[B, output_shape]` where `B` is the batch size. I also need the batch size because I want to do some expectations on intermediate values.
I've just realized through writing this comment that I may have noticed where I've made my mistake. I was attaching my hook to the weights of an `nn.Module`, which is incorrect. I assume I'd have to create a Tensor `out` for all `nn.Module` objects in my network and then manually register a hook on those `out` variables? Then that would give me the `grad_output` you mentioned? Is that correct?
If I could ask one follow-up question? `grad_output` returns the gradient of the loss with respect to the output of a layer (as you said); however, does this apply if the loss contains derivatives itself? The reason I ask this is because I want to use `grad_output` to do gradient decomposition (i.e. separate per-sample gradients into the outer product of 2 other Tensors, e.g. https://arxiv.org/pdf/1510.01799.pdf). However, this seems to be applicable to only 1st-order derivatives unless I'm mistaken? For example, the per-sample gradients for an `nn.Linear` layer are just the outer product of `grad_output` and its activations.
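As a first-order sanity check (illustrative, no vmap involved), that outer-product identity for `nn.Linear` can be verified directly:

```python
import torch
from torch import nn

lin = nn.Linear(2, 3)
x = torch.randn(5, 2)
cache = {}

# stash the activations (module input) and grad_output
lin.register_forward_pre_hook(lambda m, inp: cache.__setitem__("a", inp[0]))
out = lin(x)
out.register_hook(lambda g: cache.__setitem__("g", g))
(out ** 2).sum().backward()  # per-sample losses summed, not averaged

# per-sample weight gradients as an outer product: [B, out_dim, in_dim]
per_sample = torch.einsum("bi,bj->bij", cache["g"], cache["a"])
assert torch.allclose(per_sample.sum(0), lin.weight.grad)
```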
Thank you!
> I assume I'd have to create a Tensor out for all nn.Module objects in my network and then manually register a hook on those out variables? Then that would give me the grad_output you mentioned? Is that correct?
You don't have to do this, you can actually use `your_mod.register_forward_hook(special_hook)`. This `special_hook` function will be called with all the outputs of `your_mod`'s forward. So that hook can look like:
```python
def special_hook(mod, inputs, outputs):
    # forward hooks receive the module's return value directly;
    # for a module returning a single Tensor, hook that Tensor
    out = outputs[0] if isinstance(outputs, tuple) else outputs
    out.register_hook(my_backward_hook)
```
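A hedged usage sketch against the earlier repro (`my_backward_hook` is a hypothetical Tensor hook):

```python
def my_backward_hook(grad):
    # gradient of the loss w.r.t. the module's output
    print("backward", grad.shape)

for mod in (net.fc1, net.fc2):
    mod.register_forward_hook(special_hook)
```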
That's a neat trick! So this would basically do what `full_backward_hook` does, but for an `nn.Module` object with only 1 output, and it works with functorch? That kinda looks like what I'm looking for; the only issue I have is how it'll work with higher-order gradients calculated with functorch for my loss. I'll give that hook a go and see if I can get something working! Thanks for the help! :)
Not sure how autograd hooks interact with higher-order gradients done via functorch. cc @zou3519, do you have an example of that already?