
Get wrong jacobian from copyslice operation

Status: Open. KagamineLenOffical opened this issue 3 years ago • 8 comments

Hi, thanks for your great work. I'm using this code to accelerate a minimization algorithm and found a problem. Here is a small example that reproduces it:

import torch
from functorch import jacfwd
class FunctionWrapper(object):
    def __init__(self, fun):  # note that function can be a lambda expression
        self._fun = fun
        self.fevals = 0
        self.cur_x = None
        self.cur_y = None
        self.input_segment = None
        self.constant_dict = None

    def __call__(self, v, **kwargs):
        self.fevals += 1
        # self.cur_x = v.view(-1).requires_grad_()
        self.cur_x = v.view(-1)
        if self.input_segment is not None:
            x = []
            for i, _ in enumerate(self.input_segment):
                if i > 0:
                    x.append(v[self.input_segment[i - 1]:self.input_segment[i]])
            self.cur_y = self._fun(*x, **kwargs)
        else:
            self.cur_y = self._fun(self.cur_x, **kwargs)
        return self.cur_y

    def input_constructor(self, *args):  # if has time, convert to kwargs input
        l = []
        self.input_segment = [0]
        cur = 0
        for v in args:
            nv = v.view(-1)
            l.append(nv)
            cur += nv.size()[0]
            self.input_segment.append(cur)
        x = torch.concat(l).detach().requires_grad_()
        return x

if __name__ == '__main__':
    xx_ = torch.tensor([4.,5.,6.])
    yy_ = torch.tensor([7.,8.])
    def func(*args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])
    funcc = FunctionWrapper(func)  # FunctionWrapper is defined above
    xx = funcc.input_constructor(xx_[:2],yy_)
    print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc)(xx))

The FunctionWrapper counts function calls and splits/merges the input values. The output is:

tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]])
tensor([[[  0.,   0.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]], grad_fn=<ViewBackward0>)

As you can see, the jacfwd result differs from functional.jacobian: it is zero for every derivative with respect to the inputs that are slice-copied into xx_. But if we create the tensor inside the function instead, we get the right result.

if __name__ == '__main__':
    xx_ = torch.tensor([4.,5.,6.],requires_grad=True)
    yy_ = torch.tensor([7.,8.],requires_grad=True)
    def func(*args):
        x_ = torch.tensor([1.,2.,3.,4.])
        x_[:2] = args[0]
        y = args[1]
        return torch.vstack([(x_**2).sum(),(y**3).sum()])
    funcc = FunctionWrapper(func)
    xx = funcc.input_constructor(xx_[:2],yy_)
    print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc)(xx))

With the output:

tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]])
tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]], grad_fn=<ViewBackward0>)

Am I misusing this function? I can work around the problem by first splitting the input into two separate inputs, but I'm looking for a more general solution.
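As a reference point that does not depend on any autograd machinery, the repro's function can be rewritten in plain Python and differentiated by central finite differences. This is an illustrative sketch: lists stand in for tensors, and the names XX_, func, and fd_jacobian are hypothetical. It overwrites a copy of the buffer on each call, which gives the same mathematical function as the in-place version because only xx_[:2] is overwritten. The result is laid out as 2 x 4, whereas torch.vstack in the original produces shape (2, 1, 4).

```python
# Central finite differences on a plain-Python version of the repro's func.
# No autograd is involved, so this is an independent reference value.

XX_ = [4.0, 5.0, 6.0]  # stands in for the captured tensor xx_


def func(v):
    """v = [x0, x1, y0, y1], the flattened input built by input_constructor."""
    x = list(XX_)   # copy, so repeated evaluations do not interfere
    x[:2] = v[:2]   # same slice assignment as xx_[:2] = args[0]
    y = v[2:]
    return [sum(t ** 2 for t in x), sum(t ** 3 for t in y)]


def fd_jacobian(f, v, h=1e-4):
    """Return J[i][j] ~= d f_i / d v_j using central differences."""
    n_out = len(f(v))
    jac = [[0.0] * len(v) for _ in range(n_out)]
    for j in range(len(v)):
        plus, minus = list(v), list(v)
        plus[j] += h
        minus[j] -= h
        fp, fm = f(plus), f(minus)
        for i in range(n_out):
            jac[i][j] = (fp[i] - fm[i]) / (2 * h)
    return jac


jac = fd_jacobian(func, [4.0, 5.0, 7.0, 8.0])
for row in jac:
    print([round(val, 3) for val in row])
# [8.0, 10.0, 0.0, 0.0]
# [0.0, 0.0, 147.0, 192.0]
```

This agrees with the torch.autograd.functional.jacobian output, not with the zeros returned by jacfwd in the first case.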

KagamineLenOffical, Aug 23 '22

@KagamineLenOffical in the first case, does it work if you pass xx_ to the function?

xx_ = torch.tensor([4.,5.,6.])
yy_ = torch.tensor([7.,8.])
def func(xx_, *args):
    #x_ = torch.tensor([1.,2.,3.,4.])
    xx_[:2] = args[0]
    y = args[1]
    return torch.vstack([(xx_**2).sum(),(y**3).sum()])

We have a limitation (which we have not documented): you are not allowed to perform in-place operations on Tensors that were neither passed directly to func nor created inside of it. This is supposed to raise an error, but it looks like it didn't in this case.

zou3519, Aug 23 '22

@zou3519 it still doesn't seem to work if I pass xx_ to the function. Here is the code:

import torch
from functorch import jacfwd
class FunctionWrapper(object):
    def __init__(self, fun, p):  # note that function can be a lambda expression
        ...
        self.p = p
    def __call__(self, v, **kwargs):
            ...
            # here I pass the xx_ through self.p to the function
            self.cur_y = self._fun(self.p, *x, **kwargs)
            ...
        return self.cur_y

if __name__ == '__main__':
    x_replace = torch.tensor([4.,5.,6.])
    yy_ = torch.tensor([7.,8.])
    def func(xx_, *args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])

    # save the x_replace into functionWrapper
    funcc = FunctionWrapper(func, x_replace)
    xx = funcc.input_constructor(x_replace[:2], yy_)
    print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc)(xx))

And the output is the same as in the previous case.

KagamineLenOffical, Aug 24 '22

To clarify, in order to support mutation of xx_, the function being transformed (funcc) needs to accept xx_ as an argument. There are two different things that we can try here to resolve this:

  1. Is it possible to change FunctionWrapper.__call__ to accept xx_?
  2. Instead of doing a mutation on a captured variable (xx_), can you do something like the following:
def func(*args):
  x = xx_.clone()  # mutate a copy created inside func, not the captured xx_
  x[:2] = args[0]
  y = args[1]
  return torch.vstack([(x**2).sum(),(y**3).sum()])

For (2), what I'm curious about is whether it matters that you're mutating xx_ in-place.

zou3519, Aug 24 '22

What does the mutation here refer to?

In my understanding, for the first case: since jacfwd can choose which args to compute the Jacobian with respect to (unlike functional.jacobian), the following code covers this case:

class FunctionWrapper(object):
    # __init__ and input_constructor are the same as in the first post

    # pass xx_ to __call__ as the explicit argument p
    def __call__(self, p, v, **kwargs):
        self.fevals += 1
        # self.cur_x = v.view(-1).requires_grad_()
        self.cur_x = v.view(-1)
        if self.input_segment is not None:
            x = []
            for i, _ in enumerate(self.input_segment):
                if i > 0:
                    x.append(v[self.input_segment[i - 1]:self.input_segment[i]])
            # pass p through to _fun
            self.cur_y = self._fun(p, *x, **kwargs)
        else:
            self.cur_y = self._fun(p, self.cur_x, **kwargs)
        return self.cur_y

if __name__ == '__main__':
    x_replace = torch.tensor([4.,5.,6.])
    yy_ = torch.tensor([7.,8.])
    def func(xx_, *args):
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])

    # x_replace is now passed to funcc explicitly as its first argument
    funcc = FunctionWrapper(func)
    xx = funcc.input_constructor(x_replace[:2], yy_)
    print(jacfwd(funcc,argnums=1)(x_replace, xx))

And I still get the wrong Jacobian. The following version gets it right:

    def func(xx_, *args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_ = xx_.detach() # or clone()
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])

Like the case above, option (2) that you suggested also gives the correct result and solves my problem.

KagamineLenOffical, Aug 24 '22

What does the mutation here refer to?

By mutation, I'm referring to xx_[:2] = args[0]. This line mutates the tensor xx_ in-place by updating the elements.
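The aliasing side of this can be illustrated with a plain-Python analogy. Python lists have no autograd, so this sketch only shows the in-place semantics: slice assignment, like xx_[:2] = args[0] on a tensor, writes into the existing object, so the captured buffer changes for every later use, while a copy (the analogue of xx_.clone()) is a new object owned by the call. The names captured, mutating, and copying are hypothetical.

```python
captured = [4.0, 5.0, 6.0]  # stands in for the captured tensor xx_


def mutating(args0):
    captured[:2] = args0      # like xx_[:2] = args[0]: writes into captured
    return sum(t ** 2 for t in captured)


def copying(args0):
    local = list(captured)    # like xx_.clone(): a new object for this call
    local[:2] = args0         # only the copy is modified
    return sum(t ** 2 for t in local)


print(copying([10.0, 11.0]), captured)   # 257.0 [4.0, 5.0, 6.0]
print(mutating([10.0, 11.0]), captured)  # 257.0 [10.0, 11.0, 6.0]
```

Both calls return the same value, but only the mutating version leaves the outer buffer changed after it returns.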

For the first case you provided, would it be possible to get a runnable repro script so that we could dig into it? It is a bit difficult to tell what is going on from the snippet alone.

zou3519, Aug 24 '22

Thanks for your reply! Here is a minimal version that reproduces this situation:

import torch
from functorch import jacfwd

if __name__ == '__main__':
    x_replace = torch.tensor([4.,5.,6.])
    input_x = torch.tensor([10., 11.])
    yy_ = torch.tensor([7.,8.])
    def func(xx_, *args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_[:2] = args[0]
        return (xx_**2).sum()

    # print(torch.autograd.functional.jacobian(func,input_x))
    print(jacfwd(func, argnums=1)(x_replace, input_x))

And this is the full code for the first case:

import torch
from functorch import jacfwd
class FunctionWrapper(object):
    def __init__(self, fun):  # note that function can be a lambda expression
        self._fun = fun
        self.fevals = 0
        self.cur_x = None
        self.cur_y = None
        self.input_segment = None
        self.constant_dict = None
    def __call__(self, p, v, **kwargs):
        self.fevals += 1
        # self.cur_x = v.view(-1).requires_grad_()
        self.cur_x = v.view(-1)
        if self.input_segment is not None:
            x = []
            for i, _ in enumerate(self.input_segment):
                if i > 0:
                    x.append(v[self.input_segment[i - 1]:self.input_segment[i]])
            self.cur_y = self._fun(p, *x, **kwargs)
        else:
            self.cur_y = self._fun(p, self.cur_x, **kwargs)
        return self.cur_y

    def input_constructor(self, *args):  # if has time, convert to kwargs input
        l = []
        self.input_segment = [0]
        cur = 0
        for v in args:
            nv = v.view(-1)
            l.append(nv)
            cur += nv.size()[0]
            self.input_segment.append(cur)
        x = torch.concat(l).detach().requires_grad_()
        return x

if __name__ == '__main__':
    x_replace = torch.tensor([4.,5.,6.])
    input_x = torch.tensor([10., 11.])
    yy_ = torch.tensor([7.,8.])
    def func(xx_, *args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])

    funcc = FunctionWrapper(func)
    xx = funcc.input_constructor(input_x, yy_)
    # print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc,argnums=1)(x_replace, xx))

KagamineLenOffical, Aug 25 '22

My apologies for the delayed reply. Thank you for the repro; it clearly demonstrates the problem. We'll dig into it more.

zou3519, Sep 06 '22

After digging into this a little bit, here's what's happening and how we can fix it:

(1) In BOTH the cases where xx_ is captured AND passed in, it's not wrapped as a GradTensor (this is probably a bug that we should fix, see end).
(2) In both cases, we call slice(xx_) for xx_[:2]. Since the AD interpreter is always called (and the front layer key is mode-style), we'll call the AD interpreter even though xx_ is not a GradWrapper.
(3) The AD interpreter wraps all outputs as GradTensors, so slice(xx_) is a GradTensor, and during the copy_ we'll write to its grad function, which will update the grad_fn of xx_ (the unwrapped tensor).
(4) When we use xx_ later, we're using the unwrapped version, not the GradWrapper one (since xx_ doesn't know that the GradWrapper was created around it). So the output doesn't think that xx_ participates in this version of autograd.

Solutions:

(1) Right now jacrev deals with its argnum argument by making a version of the function that captures the free variables. Instead, we should add argnum to vjp and then do what grad does for its argnum (wrap all the arguments as GradTensors but only set _requires_grad on a subset of them). Then, we should make jacrev use in_dims and argnums properly. I'm not certain what the behavior here will be, but we should double check that this works.

(2) We should either (a) make the AD interpreter not run if the GradWrappers key is not set, or (b) make it not wrap its outputs if none of the inputs are GradWrappers. Even if (1) works, users should still get errors when writing the original code.
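The idea behind proposal (1), wrapping every argument but only seeding derivatives for the selected ones, can be illustrated with a toy forward-mode AD in plain Python. This is a hypothetical sketch, not functorch's implementation: Dual and toy_jacfwd are made-up names, and argnums is limited to a single int. Each forward pass seeds a unit tangent on one entry of args[argnums] (one basis vector per Jacobian column), while all other arguments, including xx_, are still wrapped with zero tangents. Because xx_ is wrapped, the in-place write into it carries the tangents through to the output.

```python
class Dual:
    """A value paired with its directional derivative (tangent)."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = self._lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__  # lets sum() start from the integer 0

    def __mul__(self, other):
        other = self._lift(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)


def toy_jacfwd(f, argnums):
    """Jacobian of f w.r.t. the list args[argnums], one forward pass per column."""
    def jac_fn(*args):
        n = len(args[argnums])
        rows = None
        for j in range(n):
            # Wrap EVERY argument; only entry j of args[argnums] gets tangent 1.
            wrapped = [
                [Dual(p, 1.0 if (k == argnums and i == j) else 0.0)
                 for i, p in enumerate(arg)]
                for k, arg in enumerate(args)
            ]
            out = f(*wrapped)
            if rows is None:
                rows = [[0.0] * n for _ in out]
            for i, o in enumerate(out):
                rows[i][j] = o.tangent
        return rows
    return jac_fn


def func(xx_, v):
    xx_[:2] = v[:2]   # in-place write into a wrapped argument
    y = v[2:]
    return [sum(t * t for t in xx_), sum(t * t * t for t in y)]


x_replace = [4.0, 5.0, 6.0]
# input_x = [10, 11] and yy_ = [7, 8], concatenated as in the repro above
jac = toy_jacfwd(func, argnums=1)(x_replace, [10.0, 11.0, 7.0, 8.0])
print(jac)  # [[20.0, 22.0, 0.0, 0.0], [0.0, 0.0, 147.0, 192.0]]
```

In the reported bug the difference is that xx_ was never wrapped, so per step (4) the output did not see its in-place update as part of the AD computation.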

samdow, Sep 12 '22