
Feature request: fast way to approximate the diagonal of the hessian

Open fuzihaofzh opened this issue 2 years ago • 8 comments

Hi, I want to get the second-order gradients for an nn.Module function. I use make_functional_with_buffers to wrap it, but getting the second-order gradients fails with a RuntimeError. The code is as follows:

import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(params, buffers, x, t)
grad_weights2 = grad(grad(compute_loss))(params, buffers, x, t)

I got the following error:

---> 17 grad_weights2 = grad(grad(compute_loss))(params, buffers, x, t)

File ~/programs/miniconda3/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1188, in grad.<locals>.wrapper(*args, **kwargs)
   1186 @wraps(func)
   1187 def wrapper(*args, **kwargs):
-> 1188     results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
   1189     if has_aux:
   1190         grad, (_, aux) = results

File ~/programs/miniconda3/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1068, in grad_and_value.<locals>.wrapper(*args, **kwargs)
   1065     output, aux = output
   1067 if not isinstance(output, torch.Tensor):
-> 1068     raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
   1069                        f'to return a Tensor, got {type(output)}')
   1070 if output.dim() != 0:
   1071     raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
   1072                        'to return a scalar Tensor, got tensor with '
   1073                        f'{output.dim()} dims. Maybe you wanted to '
   1074                        'use the vjp or jacrev APIs instead?')

RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a Tensor, got <class 'tuple'>

functorch was installed directly from source.

fuzihaofzh avatar Apr 07 '22 07:04 fuzihaofzh

When you say second order gradients -- do you want the hessian?

You can think of the model as a R^n -> R^1 function. The first order gradients have shape R^n. If you're looking for the hessian, then that's going to have shape R^n x R^n.

The way to compute the hessian w.r.t your parameters is to either do hessian(compute_loss)(params, buffers, x, t) or jacrev(jacrev(compute_loss))(params, buffers, x, t). See this tutorial for more details: https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html#hessian-computation-with-functorch-hessian
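As a minimal sketch of the shape claim above (using core PyTorch's stable torch.autograd.functional.hessian on a toy scalar function, purely for illustration; the function f below is made up for the example):

```python
import torch
from torch.autograd.functional import hessian

# f: R^3 -> R, so grad(f) has shape (3,) and the Hessian has shape (3, 3).
def f(w):
    return (w ** 2).sum() + w.prod()

w = torch.randn(3)
H = hessian(f, w)
print(H.shape)  # torch.Size([3, 3])
```

The same shape reasoning applies per-parameter when differentiating a loss w.r.t. model parameters, which is why the full Hessian becomes large quickly.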

zou3519 avatar Apr 07 '22 14:04 zou3519

@zou3519 Thanks for your help. Since the Hessian of a big network is extremely large, I just want to compute the diagonal of the Hessian. I wonder whether there is any API I can use to calculate it. I think this is a widely used trick when directly computing the full Hessian is too costly.

fuzihaofzh avatar Apr 07 '22 15:04 fuzihaofzh

I think this is a widely used trick when directly computing the full Hessian is too costly.

Is there a widely used trick to get the diagonal of the Hessian?

zou3519 avatar Apr 08 '22 15:04 zou3519

I know it is a little bit tricky. I don't know whether this thread may give some help: https://discuss.pytorch.org/t/second-order-derivatives-of-loss-function/71797/3 .

fuzihaofzh avatar Apr 08 '22 19:04 fuzihaofzh

@fuzihaofzh Is the hessian with respect to the inputs of your model or the parameters of your model? There is a trick for the inputs but none exists (to my knowledge at least) for the parameters.

AlphaBetaGamma96 avatar Apr 10 '22 15:04 AlphaBetaGamma96

Hi @fuzihaofzh, I came across this repo (https://github.com/amirgholami/adahessian) and thought you might find it interesting as it uses a trick for the Hessian via Hutchinson's estimator.
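For reference, Hutchinson's estimator can be sketched in plain PyTorch (a hypothetical standalone example, not the adahessian code): since E[v ⊙ Hv] = diag(H) for Rademacher vectors v, one can average v * Hv over random draws, with each Hessian-vector product Hv obtained by double backward:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 3)
x, t = torch.randn(4, 3), torch.randn(4, 3)
loss = nn.functional.mse_loss(model(x), t)

# First backward with create_graph=True so we can differentiate again.
params = list(model.parameters())
grads = torch.autograd.grad(loss, params, create_graph=True)

n_samples = 100
diag_est = [torch.zeros_like(p) for p in params]
for _ in range(n_samples):
    # Rademacher vectors: entries are +1 or -1 with equal probability.
    vs = [torch.randint_like(p, high=2) * 2.0 - 1.0 for p in params]
    # Hessian-vector products via a second backward pass.
    Hvs = torch.autograd.grad(grads, params, grad_outputs=vs, retain_graph=True)
    for d, v, Hv in zip(diag_est, vs, Hvs):
        d += v * Hv / n_samples

print([d.shape for d in diag_est])  # [torch.Size([3, 3]), torch.Size([3])]
```

This never materializes the full Hessian: each iteration costs one extra backward pass, and the estimate sharpens as n_samples grows.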

AlphaBetaGamma96 avatar Apr 21 '22 15:04 AlphaBetaGamma96

Hi @AlphaBetaGamma96, the repo you shared is extremely useful to me. I am trying to reimplement this trick in my project. Thanks very much for your help. @zou3519, for this question, I think @AlphaBetaGamma96's answer is a simple and useful way to do the calculation, and I believe it would be quite useful if it could be implemented in functorch. I just found another tool that provides a similar solution:

https://docs.backpack.pt/en/master/extensions.html?highlight=per-sample%20diagonal%20Hessian#backpack.extensions.BatchDiagHessian

Hope it could help to somehow improve functorch. Thanks.

fuzihaofzh avatar Apr 24 '22 09:04 fuzihaofzh

Thanks for the discussion! We'll take a look at the resources mentioned.

zou3519 avatar Apr 25 '22 18:04 zou3519