functorch
Feature request: fast way to approximate the diagonal of the hessian
Hi,
I want to compute second-order gradients for an nn.Module. I wrap it with make_functional_with_buffers, but computing the second-order gradients fails with a RuntimeError. The code is as follows:
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
def compute_loss(params, buffers, x, t):
    y = func(params, buffers, x)
    return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, buffers, x, t)
grad_weights2 = grad(grad(compute_loss))(params, buffers, x, t)
I got the following error:
---> 17 grad_weights2 = grad(grad(compute_loss))(params, buffers, x, t)
File ~/programs/miniconda3/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1188, in grad.<locals>.wrapper(*args, **kwargs)
1186 @wraps(func)
1187 def wrapper(*args, **kwargs):
-> 1188 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1189 if has_aux:
1190 grad, (_, aux) = results
File ~/programs/miniconda3/lib/python3.9/site-packages/functorch/_src/eager_transforms.py:1068, in grad_and_value.<locals>.wrapper(*args, **kwargs)
1065 output, aux = output
1067 if not isinstance(output, torch.Tensor):
-> 1068 raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
1069 f'to return a Tensor, got {type(output)}')
1070 if output.dim() != 0:
1071 raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
1072 'to return a scalar Tensor, got tensor with '
1073 f'{output.dim()} dims. Maybe you wanted to '
1074 'use the vjp or jacrev APIs instead?')
RuntimeError: grad_and_value(f)(*args): Expected f(*args) to return a Tensor, got <class 'tuple'>
functorch was installed directly from source.
When you say second order gradients -- do you want the hessian?
You can think of the model as a R^n -> R^1 function. The first order gradients have shape R^n. If you're looking for the hessian, then that's going to have shape R^n x R^n.
The way to compute the Hessian w.r.t. your parameters is either hessian(compute_loss)(params, buffers, x, t) or jacrev(jacrev(compute_loss))(params, buffers, x, t). See this tutorial for more details: https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html#hessian-computation-with-functorch-hessian
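For reference, here is a minimal runnable sketch of that full-Hessian computation. Note that this uses torch.func (where the functorch transforms live in recent PyTorch releases) and functional_call in place of make_functional_with_buffers, so it is an adaptation of the thread's code rather than a verbatim copy:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, hessian

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
params = dict(model.named_parameters())

def compute_loss(params, x, t):
    # Run the module statelessly with the given parameter dict
    y = functional_call(model, params, (x,))
    return nn.functional.mse_loss(y, t)

# Hessian of the scalar loss w.r.t. every pair of parameter tensors,
# returned as a nested dict keyed by parameter name
H = hessian(compute_loss)(params, x, t)
# e.g. H['weight']['weight'] has shape (3, 3, 3, 3),
#      H['bias']['bias'] has shape (3, 3)
```

This makes the size problem concrete: for a weight of shape (3, 3), the weight-weight block alone already has 3×3×3×3 entries, which is why materializing the full Hessian does not scale to large networks.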
@zou3519 Thanks for your help. Since the Hessian of a big network is extremely large, I just want to compute its diagonal. I wonder whether there is an API I can use to calculate it. I think it is a widely used trick when directly computing the Hessian is too costly.
I think it is a widely used trick when directly computing the Hessian is too costly.
Is there a widely used trick to get the diagonal of the Hessian?
I know it is a little bit tricky. I don't know whether this thread helps: https://discuss.pytorch.org/t/second-order-derivatives-of-loss-function/71797/3
@fuzihaofzh Is the hessian with respect to the inputs of your model or the parameters of your model? There is a trick for the inputs but none exists (to my knowledge at least) for the parameters.
Hi @fuzihaofzh, I came across this repo (https://github.com/amirgholami/adahessian) and thought you might find it interesting, as it estimates the Hessian diagonal with a trick via Hutchinson's estimator.
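Hutchinson's estimator approximates diag(H) as E_z[z ⊙ (Hz)] with random Rademacher (±1) probe vectors z, so it only needs Hessian-vector products, never the full Hessian. Below is a hedged sketch using plain torch.autograd (not a functorch API); the helper name hutchinson_diag and the sample count are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
params = list(model.parameters())

def hutchinson_diag(loss, params, n_samples=500):
    """Estimate diag(H) as the average of z * (H @ z) over Rademacher z."""
    # First-order grads with create_graph=True so we can differentiate again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    est = [torch.zeros_like(p) for p in params]
    for _ in range(n_samples):
        # Rademacher probe vectors (+1/-1), one per parameter tensor
        zs = [torch.randint_like(p, high=2) * 2.0 - 1.0 for p in params]
        # Hessian-vector product: gradient of <grads, z> w.r.t. the parameters
        hvps = torch.autograd.grad(grads, params, grad_outputs=zs,
                                   retain_graph=True)
        for e, z, hv in zip(est, zs, hvps):
            e += z * hv / n_samples
    return est

loss = nn.functional.mse_loss(model(x), t)
diag = hutchinson_diag(loss, params)
# diag[0] estimates the Hessian diagonal for the weight, diag[1] for the bias
```

Each estimate has the same shape as its parameter, so memory stays O(n) instead of O(n^2); the trade-off is Monte Carlo noise from the off-diagonal entries, which shrinks as n_samples grows.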
Hi @AlphaBetaGamma96, the paper you shared is extremely useful to me. I am trying to reimplement this trick in my project. Thank you very much for your help. @zou3519 For this question, I think @AlphaBetaGamma96's answer is a simple and useful way to do the calculation, and I believe it would be quite useful if it could be implemented in functorch. I just found another tool that provides a similar solution:
https://docs.backpack.pt/en/master/extensions.html?highlight=per-sample%20diagonal%20Hessian#backpack.extensions.BatchDiagHessian
I hope it helps to improve functorch. Thanks.
Thanks for the discussion! We'll take a look at the resources mentioned