function(Jacobian)-dot-vector and vector-Jacobian-vector function
Hi,
I'd like to use functorch to realize the following loss.

Question description

Assume that:
- The dimension of the output tensor is $O$, and we use $y^\gamma$ to mark each element.
- The dimension of the input tensor (the `primal`) is $I$, and we use $x_\alpha$ to mark each element.
- We have a PyTorch model $f$ with parameters marked as $W$ that maps the input to the output, $f:\vec{x}(R^I) \rightarrow \vec{y}(R^O)$.
- There exists the Jacobian matrix $(O\times I)$, marked $J_\alpha^\gamma=\frac{\partial y^{\gamma}}{\partial x_\alpha}$.
I want to calculate two terms,
$$ L1=\sum_\gamma(\sum_\alpha J_\alpha^{\gamma}-1)^2 $$
$$ L2 =\sum_\gamma [\sum_\alpha (J_\alpha^{\gamma})^2-1]^2 $$
as well as their gradients with respect to $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient descent update.
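To spell out the structure: the first inner sum is a single Jacobian-vector product with an all-ones tangent, while the second is the squared norm of a Jacobian row,

$$ \sum_\alpha J_\alpha^{\gamma} = (J\,\vec{1})^{\gamma}, \qquad \sum_\alpha (J_\alpha^{\gamma})^2 = \lVert J^{\gamma} \rVert^2, $$

which is why the two terms behave very differently below.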
This is easy to realize with the help of functorch; I post a toy example below:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch
import torch.nn.functional as F
import functorch
from functorch import jacrev, jacfwd
from functorch import make_functional, vmap, grad

B = 200   # batch size
I = 100   # input dimension
O = 300   # output dimension

class MyModel(torch.nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.backbone = torch.nn.Linear(in_chan, out_chan, bias=False)

    def forward(self, x):
        return self.backbone(x) ** 2

model = MyModel(I, O).cuda()
x = torch.randn(B, I).cuda()
cotangents = torch.ones(B, I).cuda()   # all-ones tangent, so jvp gives sum_alpha J_alpha^gamma
func_model, params = make_functional(model)

### ---> to calculate the dL1/dW term
def Normalization_Term_1(params, x):
    # jvp with the all-ones tangent returns sum_alpha J_alpha^gamma, shape (B, O)
    return ((functorch.jvp(lambda x: func_model(params, x), (x,), (cotangents,))[1] - 1) ** 2).mean()

Derivation_Term_1 = jacrev(Normalization_Term_1, argnums=0)(params, x)

### ---> to calculate the dL2/dW term
def Normalization_Term_2(params, x):
    # the full per-sample Jacobian of shape (B, O, I) is materialized here
    full_jac = vmap(jacrev(func_model, argnums=1), (None, 0))(params, x)
    return (((full_jac ** 2).sum(-1) - 1) ** 2).mean()

Derivation_Term_2 = jacrev(Normalization_Term_2, argnums=0)(params, x)
```
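As a side note on the snippet above (not essential to the question): since `Normalization_Term_1` returns a scalar, `functorch.grad` should give the same derivative as `jacrev(..., argnums=0)` here.

```python
# equivalent to the jacrev call above because the loss is a scalar
Derivation_Term_1_alt = grad(Normalization_Term_1, argnums=0)(params, x)
```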
Problem
The idea is to calculate:
- $\sum_\alpha J_\alpha^{\gamma}$: this term is easy to realize with `functorch.jvp` or `torch.autograd.functional.jvp` by setting the `cotangents` to an all-ones tensor `torch.ones(B, I)`. If we do the summation $\sum_\gamma$ inside the wrapped function and then compute the Jacobian with respect to the model's parameters $W$, it runs fast and uses little memory.
- However, for the next term $\sum_\alpha (J_\alpha^{\gamma})^2$, there is no `jvp`-like shortcut: I have to build the full Jacobian of the `primal` and follow it with a `.sum()` to obtain the result. In that case I run into an OOM problem. My machine is an A100-80G.
I suppose this is because, in the second case, we have to access the full Jacobian matrix $J_\alpha^{\gamma}$, which is too large to store during the computation.
The OOM issue is also reported in https://github.com/pytorch/functorch/issues/636#issue-1185946292 and is (possibly) addressed by the recent update with the `chunks` option in https://github.com/pytorch/functorch/issues/680#issue-1197453691.
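In the same spirit as that `chunks` option, here is a rough sketch of the kind of thing I mean (my own untested workaround on top of the toy example above, not an existing functorch API; `row_norm_sq` and `chunk_size` are names I made up): push chunks of standard basis tangents through `functorch.jvp` and accumulate the squares, so the full $B \times O \times I$ Jacobian is never stored at once.

```python
# Sketch only: accumulate sum_alpha (J_alpha^gamma)^2 from chunks of basis tangents.
def row_norm_sq(params, x, chunk_size=10):
    out = 0.0
    for start in range(0, I, chunk_size):
        end = min(start + chunk_size, I)
        basis = torch.eye(I, device=x.device)[start:end]             # (chunk, I) basis vectors e_alpha
        tangents = basis[:, None, :].expand(-1, B, -1).contiguous()  # (chunk, B, I), same e_alpha per batch row
        # each jvp gives one Jacobian column J[:, :, alpha] of shape (B, O)
        cols = vmap(
            lambda t: functorch.jvp(lambda x_: func_model(params, x_), (x,), (t,))[1]
        )(tangents)                                                  # (chunk, B, O)
        out = out + (cols ** 2).sum(0)                               # running sum over alpha
    return out                                                       # (B, O) = sum_alpha (J_alpha^gamma)^2

def Normalization_Term_2_chunked(params, x):
    return ((row_norm_sq(params, x) - 1) ** 2).mean()

# dL2/dW via grad, since the loss is a scalar; memory behaviour not verified
Derivation_Term_2_chunked = grad(Normalization_Term_2_chunked, argnums=0)(params, x)
```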
My ideas are:

- Can we build a native function that produces an `F(Jacobian)-dot-vector` output vector, $f(J)\cdot \vec{n}\rightarrow \vec{v}$?
  - If $f: x\rightarrow x$, then it is `functorch.jvp`: $J\cdot \vec{n}\rightarrow \vec{v}$.
  - If $f: x\rightarrow x^2$, then it is the second term in my example (essentially a native version of the chunked sketch above); since it would not need to access the full Jacobian, it would be much more memory efficient.
- Some usages of the Jacobian only require:
  - `Jacobian-dot-vector`: produces a vector, covered by `functorch.jvp`
  - `vector-dot-Jacobian`: produces a vector, covered by `functorch.vjp`
  - `vector-dot-Jacobian-dot-vector`: produces a scalar, which currently has to be realized via `jvp` or `vjp` (see the sketch after this list)

  When computing gradients of those outputs, the memory needed to store the intermediate tensors is around `D of vector` x `N of parameters`. Is it possible to realize a native `vector-dot-Jacobian-dot-vector` that does not access those large intermediates and is therefore memory efficient?
- I checked the source code of `jvp`: it directly uses the `dual` mode of PyTorch forward-mode AD and returns the `jvp` term straight from `_unpack_dual`, so I am afraid this problem may be beyond the scope of the functorch pipeline.
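To be concrete about the `vector-dot-Jacobian-dot-vector` case mentioned above, this is roughly what I have in mind (a sketch on top of the toy example; `u` is an arbitrary left-hand vector I made up for illustration): one `functorch.jvp` plus a dot product already yields the scalar $\vec{u}\cdot(J\vec{n})$ without ever touching the full Jacobian.

```python
# Sketch: u^T J n as a scalar from one jvp plus a dot product; the full Jacobian is never formed.
u = torch.randn(B, O).cuda()   # arbitrary left-hand vector (illustration only)
n = torch.ones(B, I).cuda()    # right-hand vector

def uJn(params, x):
    _, Jn = functorch.jvp(lambda x_: func_model(params, x_), (x,), (n,))  # Jn has shape (B, O)
    return (u * Jn).sum()

# gradient of the scalar with respect to the parameters W
duJn_dW = grad(uJn, argnums=0)(params, x)
```

The question is whether the gradient computation on top of such outputs can also be kept memory efficient.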
Anyway, I look forward to your discussion.