
function(Jacobian)-dot-vector and vector-Jacobian-vector function


Hi,

I'd like to use functorch to implement the following loss:

Setup

Assume that:

  • the output tensor has dimension $O$, and we use $y^\gamma$ to denote each element;
  • the input tensor (the primal) has dimension $I$, and we use $x_\alpha$ to denote each element;
  • we have a PyTorch model $f$ with parameters $W$ that maps the input to the output, $f:\vec{x}\in\mathbb{R}^I \rightarrow \vec{y}\in\mathbb{R}^O$.

Then there is a Jacobian matrix of shape $O\times I$, denoted $J_\alpha^\gamma=\frac{\partial y^{\gamma}}{\partial x_\alpha}$.

I want to calculate two terms,

$$ L1=\sum_\gamma(\sum_\alpha J_\alpha^{\gamma}-1)^2 $$

$$ L2 =\sum_\gamma [\sum_\alpha (J_\alpha^{\gamma})^2-1]^2 $$

as well as their gradients with respect to $W$, $\frac{\partial L1}{\partial W}$ and $\frac{\partial L2}{\partial W}$, for the gradient-descent update.
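
For concreteness, a naive reference for the two losses, written as if the full per-sample Jacobian $J$ of shape $(O, I)$ were already available (the helper name below is mine, purely for illustration):

import torch

def losses_from_dense_jacobian(J):
    # J: (O, I) Jacobian of a single sample, J[gamma, alpha] = d y^gamma / d x_alpha
    L1 = ((J.sum(dim=1) - 1) ** 2).sum()          # sum_gamma (sum_alpha J_alpha^gamma - 1)^2
    L2 = (((J ** 2).sum(dim=1) - 1) ** 2).sum()   # sum_gamma (sum_alpha (J_alpha^gamma)^2 - 1)^2
    return L1, L2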

This is easy to realize with the help of functorch; I post a toy example below.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import torch
import functorch
from functorch import jacrev, vmap, make_functional

B = 200  # batch size
I = 100  # input dimension
O = 300  # output dimension

class MyModel(torch.nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.backbone = torch.nn.Linear(in_chan, out_chan, bias=False)

    def forward(self, x):
        return self.backbone(x) ** 2

model = MyModel(I, O).cuda()
x = torch.randn(B, I).cuda()
# all-ones tangent for jvp, so jvp returns sum_alpha J_alpha^gamma per sample
cotangents = torch.ones(B, I).cuda()
func_model, params = make_functional(model)

### ---> to calculate the dL1/dW term via jvp (no full Jacobian needed)
def Normalization_Term_1(params, x):
    Jv = functorch.jvp(lambda x: func_model(params, x), (x,), (cotangents,))[1]  # (B, O)
    return ((Jv - 1) ** 2).mean()

Derivation_Term_1 = jacrev(Normalization_Term_1, argnums=0)(params, x)

### ---> to calculate the dL2/dW term: needs the full per-sample Jacobian (B, O, I)
def Normalization_Term_2(params, x):
    J = vmap(jacrev(func_model, argnums=1), (None, 0))(params, x)  # (B, O, I)
    return (((J ** 2).sum(-1) - 1) ** 2).mean()  # reduce over gamma and batch, mirroring term 1

Derivation_Term_2 = jacrev(Normalization_Term_2, argnums=0)(params, x)

Problem

The idea is to calculate:

  • $\sum_\alpha J_\alpha^{\gamma}$: this term is easy to realize with functorch.jvp (or torch.autograd.functional.jvp) by setting the tangent (the cotangents tensor in the code above) to an all-ones tensor torch.ones(B, I). If we do the summation $\sum_\gamma$ inside the wrapped function and then take the Jacobian with respect to the model's parameters $W$, it runs fast and uses little memory.
  • However, for the next term $\sum_\alpha (J_\alpha^{\gamma})^2$ there is no jvp-style shortcut, and I have to build the full Jacobian of the primal followed by a .sum() to obtain the result. In that case I run into an OOM, even though my machine is an A100-80G.

I suppose this is because the second case has to materialize the full Jacobian matrix $J_\alpha^{\gamma}$, which is too large to store during the computation.

The OOM issue is also reported in https://github.com/pytorch/functorch/issues/636#issue-1185946292 and is (possibly) addressed by the recent chunks option added in https://github.com/pytorch/functorch/issues/680#issue-1197453691.
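
A manual version of that chunking idea for the second term is to push chunks of standard-basis tangents through jvp and accumulate the squares, so that only one chunk of Jacobian columns is alive at a time. This is only a sketch (the helper name and chunk size are mine); it still performs $I$ jvps in total, and if the result is differentiated again with respect to the parameters the saved intermediates can still grow, so it only caps the forward peak memory:

import torch
import functorch
from functorch import vmap

def rowwise_sq_jvp_sum(func_model, params, x, chunk_size=16):
    # Compute sum_alpha (J_alpha^gamma)^2 per sample, shape (B, O),
    # without materializing the full (B, O, I) Jacobian at once.
    B, I = x.shape
    basis = torch.eye(I, device=x.device, dtype=x.dtype)      # standard basis vectors e_alpha
    total = None
    for start in range(0, I, chunk_size):
        chunk = basis[start:start + chunk_size]                # (C, I)
        tangents = chunk[:, None, :].expand(-1, B, -1)         # (C, B, I): e_alpha broadcast over the batch
        # jvp with tangent e_alpha yields the Jacobian column J[:, :, alpha] of shape (B, O);
        # vmap maps this over the C basis vectors in the chunk.
        cols = vmap(
            lambda t: functorch.jvp(lambda x_: func_model(params, x_), (x,), (t,))[1]
        )(tangents)                                            # (C, B, O)
        part = (cols ** 2).sum(0)                              # accumulate over alpha within the chunk
        total = part if total is None else total + part
    return total                                               # (B, O): sum_alpha (J_alpha^gamma)^2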

My ideas are:

  • Can we build a native function that produces an f(Jacobian)-dot-vector output, $f(J)\cdot \vec{n}\rightarrow \vec{v}$?

    If $f:x\rightarrow x$, this is exactly functorch.jvp, $J\cdot \vec{n}\rightarrow \vec{v}$.

    If $f:x\rightarrow x^2$, it is the second term in my example, but since it never needs to access the full Jacobian, it would be much more memory efficient.

  • Many usages of the Jacobian only require:

    • Jacobian-dot-vector producing a vector, covered by functorch.jvp;
    • vector-dot-Jacobian producing a vector, covered by functorch.vjp;
    • vector-dot-Jacobian-dot-vector producing a scalar, which has to be assembled from jvp or vjp (see the sketch after this list).

    When taking gradients of such outputs with respect to $W$, the memory needed for the intermediate tensors is only around (dimension of the vectors) × (number of parameters). Would it be possible to provide a native vector-dot-Jacobian-dot-vector that never touches the large intermediate Jacobian and stays memory efficient?
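
For the vector-dot-Jacobian-dot-vector case, a memory-efficient $\vec{u}^{T} J \vec{v}$ can already be assembled from the existing primitives, since only the vector $J\vec{v}$ (or $\vec{u}^{T}J$) is ever materialized. A minimal sketch (the helper names are mine):

import torch
import functorch

def uJv_forward(func, x, u, v):
    # u^T J v without building J: forward-mode jvp gives the vector J v,
    # then a dot product with u collapses it to a scalar.
    _, Jv = functorch.jvp(func, (x,), (v,))
    return (u * Jv).sum()

def uJv_reverse(func, x, u, v):
    # The same quantity via reverse mode: vjp gives the vector u^T J,
    # then a dot product with v.
    _, vjp_fn = functorch.vjp(func, x)
    uJ, = vjp_fn(u)
    return (uJ * v).sum()

Differentiating such a scalar with respect to the parameters (e.g. with grad or jacrev over params) should never materialize a full Jacobian, which is the memory behaviour asked about above.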


I checked the source code of jvp: it directly uses the dual-number mode of PyTorch forward-mode AD and returns the jvp term straight from _unpack_dual, so I am afraid this problem may be beyond the scope of the current functorch pipeline.

Anyway, I look forward to your discussion.

veya2ztn · Oct 28 '22