backward hooks aren't called when using vmap

Open AlphaBetaGamma96 opened this issue 2 years ago • 16 comments

Hi All,

I've been trying to get the forward activations and backward sensitivities for layers within my model. The forward_pre_hook works as expected, but the full backward hook (register_full_backward_hook) never gets called.

A minimal reproducible example of this error is below:

import torch
from torch import nn
from functorch import make_functional, vmap

class model(nn.Module):

  def __init__(self):
    super(model, self).__init__()
    self.fc1 = nn.Linear(2,32)
    self.fc2 = nn.Linear(32,2)
    
  def forward(self, x):
    x = self.fc1(x)
    x = self.fc2(x)
    return x
    
net = model()

#forward-pre hook and backward full hook
def _save_input(module, input) -> None:
  print("forward")

def _save_output(module, grad_input, grad_output) -> None:
  print("backward")
  
for mod in net.modules():
  mod.register_forward_pre_hook(_save_input)
  mod.register_full_backward_hook(_save_output)
  
  
x = torch.randn(100, 2)

y = net(x)                       #works
y = vmap(net, in_dims=(0))(x)     #fails

loss = torch.mean(y)
loss.backward()

When I'm using vmap, the output is:

forward
forward
forward

When I'm not using vmap, the output is:

forward
forward
forward
backward
backward
backward

For reference: torch.__version__ is 1.12.0a0+git7c2103a and functorch.__version__ is 0.2.0a0+9d6ee76.

AlphaBetaGamma96 avatar Mar 30 '22 14:03 AlphaBetaGamma96

I am not sure what is wrong here, but it sounds very similar to autograd.Function not working. @albanD do you have some context or a reference on how nn.Module hooks work?

zou3519 avatar Mar 31 '22 13:03 zou3519

The full_backward_hook does rely on a custom Function. So if custom Functions are not supported, then yes, that won't work, I'm afraid.

albanD avatar Mar 31 '22 14:03 albanD

Cool, thanks for the analysis Alban. We are back to fixing the autograd.Function problem

zou3519 avatar Mar 31 '22 15:03 zou3519

@AlphaBetaGamma96 what are you using backward hooks for? Is it just per-sample-grad computation? If so we have other workarounds for that

zou3519 avatar Mar 31 '22 18:03 zou3519

I'm working on some gradient preconditioning techniques and it requires the forward activations and backward sensitivities (grad_output) of all nn.Module objects of a network. I am also calculating per-sample gradients and have been using vmap with jacrev/grad but I think the fact I'm using some parts of my code with vmap and other parts without is causing some issues.

One of the errors does reference autograd.Function:

RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this

AlphaBetaGamma96 avatar Mar 31 '22 18:03 AlphaBetaGamma96

@zou3519 Do you know if it's possible to compute the grad_output of a layer via vmap or is it only possible via hooks?

AlphaBetaGamma96 avatar Mar 31 '22 20:03 AlphaBetaGamma96

functorch doesn't see modules, it only sees PyTorch operations (e.g. add, mul, ...), so there isn't a way to do this right now.

zou3519 avatar Apr 01 '22 20:04 zou3519

By the way, while the nn.Module's full_backward_hook will not work, the basic autograd hooks at the Tensor level should work.

So in particular for your use case, if you just want the gradient for x, then you can get that by using a register_forward_pre_hook where the hook does:

def hook(mod, inputs):
    # forward pre-hooks receive the positional inputs as a tuple
    inputs[0].register_hook(your_backward_hook)

It is more complex and will be hard to generalize to nn.Modules that have multiple inputs (which is why the full backward hook needs a custom Function), but that might solve your problem?
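A minimal sketch of the suggestion above, in plain PyTorch without vmap. The names `layer`, `grads`, and `pre_hook` are illustrative and not from the thread. It assumes the module takes a single Tensor input, and it skips inputs that do not require grad, since a Tensor hook can only be registered on a Tensor that requires grad. Whether the same hook fires under vmap is not checked here.

```python
import torch
from torch import nn

layer = nn.Linear(2, 3)
grads = []  # illustrative storage for the gradient w.r.t. the layer input

def pre_hook(mod, inputs):
    # Forward pre-hooks receive the positional inputs as a tuple.
    x = inputs[0]
    # Tensor hooks can only be registered on Tensors that require grad.
    if x.requires_grad:
        x.register_hook(lambda g: grads.append(g.detach()))

layer.register_forward_pre_hook(pre_hook)

x = torch.randn(5, 2, requires_grad=True)
layer(x).sum().backward()

# The hook received the gradient of the loss w.r.t. x, one row per sample.
print(tuple(grads[0].shape))
```

Note that this captures the gradient with respect to the module's input (the module hook's grad_input), not its output.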

albanD avatar Apr 01 '22 21:04 albanD

Hi @albanD, thanks for the insight on how to use register_hook! Just so I understand this correctly, I can register a hook on a Tensor and use my forward_pre_hook and backward_full_hook formulae and that'd return (for all samples) the input, and the backward sensitivity (grad_output) for that Tensor?

I've managed to get around this issue by calculating grad_output via the 'normal' pytorch way and then just use functorch to calculate other higher-order gradients. I need the input/grad_output term from just the output of my model which is then used in tandem with higher-order gradients of other functions. So I think my current method should be ok, although the implementation isn't the neatest!

AlphaBetaGamma96 avatar Apr 04 '22 17:04 AlphaBetaGamma96

> I can register a hook on a Tensor and use my forward_pre_hook and backward_full_hook formulae and that'd return (for all samples) the input, and the backward sensitivity (grad_output) for that Tensor?

Yes. For a simple nn.Module, that will allow you to "simulate" the register_full_backward_hook that doesn't work.

albanD avatar Apr 05 '22 17:04 albanD

Hi @albanD, I had a quick look at register_hook but it seems that the signature of that hook only takes the gradient of a Tensor rather than the grad_output values that I need. The explicit signature is hook(grad) -> Tensor or None, so I can't seem to get grad_output. I need grad_output as I want the backward sensitivities for all samples (and the same for the activations, but forward_pre_hook works with vmap so that's not a problem).

I did notice that register_backward_hook seems to work with functorch, although it is deprecated and I'm pretty sure it has incorrect behavior for modules with multiple inputs/outputs? All the layers of my network only have a single input and output Tensor so do you think it might be ok to use register_backward_hook for the meantime? Or should I just wait for custom function support for backward hooks?

AlphaBetaGamma96 avatar May 01 '22 19:05 AlphaBetaGamma96

> the gradient of a Tensor rather than the grad_output values that I need.

What do you mean by that? The grad_output is the gradient of the output. So if you register_hook on the output of the Module, you will get the gradient of the output (also called grad_output in the module hook's doc).

albanD avatar May 02 '22 14:05 albanD

So, what I need is the gradient of the output of a module (which is what the full backward hook returns, although it doesn't work with functorch at the moment). After checking register_hook, though, it seems to only provide a Tensor of the same shape as the Tensor it's registered on, whereas the full backward hook returns a shape of [B, output_shape], where B is the batch size. I also need the batch dimension because I want to take expectations over intermediate values.

I've just realized through writing this comment I may have noticed where I've made my mistake. I was attaching my hook to the weights of an nn.Module which is incorrect. I assume I'd have to create a Tensor out for all nn.Module objects in my network and then manually register a hook on those out variables? Then that would give me grad_output you mentioned? Is that correct?

Could I ask one follow-up question? grad_output returns the gradient of the loss with respect to the output of a layer (as you said); however, does this apply if the loss contains derivatives itself? The reason I ask is that I want to use grad_output to do gradient decomposition (i.e. separate per-sample gradients into the outer product of 2 other Tensors, e.g. https://arxiv.org/pdf/1510.01799.pdf). However, this seems to be applicable only to 1st-order derivatives, unless I'm mistaken? For example, the per-sample gradient for an nn.Linear layer's weight is just the outer product of grad_output and the layer's input activations.
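The first-order outer-product claim for nn.Linear can be checked with a small sketch in plain PyTorch (no vmap or functorch). The layer sizes, the squared-output loss, and the names `per_sample` and `reference` are assumptions chosen for illustration. The loss is a sum over samples so that each sample's contribution is independent. This does not address the higher-order case asked about above.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
x = torch.randn(5, 4)  # 5 samples, 4 input features

out = layer(x)
out.retain_grad()            # keep the gradient on the non-leaf output
out.pow(2).sum().backward()  # loss summed over samples
grad_output = out.grad       # shape (5, 3)

# Per-sample weight gradient as an outer product: grad_output_b (x) x_b
per_sample = torch.einsum("bo,bi->boi", grad_output, x)  # shape (5, 3, 4)

# Reference: one separate backward pass per sample
reference = torch.stack([
    torch.autograd.grad(layer(x[i:i + 1]).pow(2).sum(), layer.weight)[0]
    for i in range(x.shape[0])
])

print(torch.allclose(per_sample, reference, atol=1e-5))
```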

Thank you!

AlphaBetaGamma96 avatar May 02 '22 14:05 AlphaBetaGamma96

> I assume I'd have to create a Tensor out for all nn.Module objects in my network and then manually register a hook on those out variables? Then that would give me grad_output you mentioned? Is that correct?

You don't have to do this, you can actually use your_mod.register_forward_hook(special_hook). This special_hook function will be called with the output of your_mod's forward. So that hook can look like:

def special_hook(mod, inputs, output):
  # for a module that returns a single Tensor, output is that Tensor (not a tuple)
  assert isinstance(output, torch.Tensor)
  output.register_hook(my_backward_hook)
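A self-contained sketch of this trick, run in plain PyTorch without vmap. The names `captured` and `make_forward_hook` are illustrative and not from the thread, and the sketch assumes every hooked module returns a single Tensor. Whether the Tensor hooks fire the same way under vmap is not verified here.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(2, 32), nn.Linear(32, 2))

captured = {}  # illustrative storage: layer name -> grad_output

def make_forward_hook(name):
    def forward_hook(mod, inputs, output):
        # output is a single Tensor for these layers; register a Tensor
        # hook on it so the backward pass hands us its gradient.
        def backward_hook(grad):
            captured[name] = grad.detach()
        output.register_hook(backward_hook)
    return forward_hook

for name, mod in net.named_children():
    mod.register_forward_hook(make_forward_hook(name))

x = torch.randn(100, 2)
loss = net(x).mean()
loss.backward()

# Each captured gradient has the shape of that layer's output,
# including the batch dimension.
for name, grad in captured.items():
    print(name, tuple(grad.shape))
```

Because the hook is registered on the layer's output, the gradient it receives matches what the module hook documentation calls grad_output for a single-output module.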

albanD avatar May 02 '22 14:05 albanD

That's a neat trick! So this would basically do what full_backward_hook does but for an nn.Module object with only 1 output and works with functorch? That kinda looks like what I'm looking for, the only issue I have is how it'll work with higher-order gradients calculated with functorch for my loss? I'll give that hook a go and see if I can get something working! Thanks for the help! :)

AlphaBetaGamma96 avatar May 02 '22 15:05 AlphaBetaGamma96

Not sure how autograd hooks interact with higher-order gradients done via functorch. cc @zou3519, do you have an example of that already?

albanD avatar May 02 '22 15:05 albanD