
Check that checkpointing works with functorch grad transforms

Open zou3519 opened this issue 3 years ago • 9 comments

See title. From discussion with @albanD

zou3519 avatar Aug 04 '22 15:08 zou3519

FYI @rohan-varma

albanD avatar Aug 04 '22 17:08 albanD

The answer is no. There are a number of things that need to be resolved.

  1. torch.autograd.grad doesn't compose with torch.utils.checkpoint
  2. autograd.Function doesn't work with functorch.grad. I guess torch.utils.checkpoint builds an autograd.Function somewhere?
  3. After we get past the above two, I think there will still be problems. It depends on when the saved variable hook gets installed: in functorch's grad(grad(g)), do we install a saved variable hook in both graphs being created, or just one of them? If the former, we're good; if the latter (which I suspect is what happens), then we've got a problem. (See the toy sketch after the script below.)

Here's a script that demonstrates (1) and (2):

import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad, make_fx

def f(x):
  y = x.sin()
  z = y.sin()
  return z

def g(x):
  return checkpoint(f, x, use_reentrant=True)

x = torch.tensor(5.)


def grad_g(x):
  x.requires_grad_()
  y = g(x)
  gx, = torch.autograd.grad(y, x)
  x.requires_grad_(False)
  return gx

# Problem 1
gm = make_fx(grad_g)(x)
print(gm)

# Problem 2
grad(g)(x)
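
To make (3) concrete, here's a toy sketch, not the checkpoint implementation itself, just plain saved-tensor hooks with logging. It shows that the hooks only wrap graphs built inside the context manager, which is why it matters where they get installed under nested grad:

import torch

# Toy illustration for (3): the non-reentrant checkpoint path relies on saved
# variable hooks (per the discussion above), and the question is which
# graph(s) end up wrapped by them. These hooks are no-ops that just log.
def pack(t):
  print("pack (during forward):", tuple(t.shape))
  return t

def unpack(t):
  print("unpack (during backward):", tuple(t.shape))
  return t

x = torch.tensor(5., requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
  y = x.sin().sin()

# The first backward unpacks the tensors that were saved under the hooks.
gx, = torch.autograd.grad(y, x, create_graph=True)
# The double-backward graph was built outside the `with` block, so the hooks
# do not cover it -- the analogue of installing the hook in only one of the
# two graphs in grad(grad(g)).
ggx, = torch.autograd.grad(gx, x)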

zou3519 avatar Aug 17 '22 20:08 zou3519

Oh yes, the default one will never work. You should pass use_reentrant=False to get the version that doesn't use a custom Function.

albanD avatar Aug 17 '22 20:08 albanD

Ahh, thanks for the clarification. It looks like functorch is silently incorrect on nested grad and checkpoint with use_reentrant=False:

import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad, make_fx

def f(x):
  y = x.sin()
  z = y.sin()
  return z

def g(x):
  return checkpoint(f, x, use_reentrant=False)

x = torch.tensor(5.)

x.requires_grad_()
y = g(x)
gx, = torch.autograd.grad(y, x, create_graph=True)
ggx, = torch.autograd.grad(gx, x)

result = grad(grad(g))(x)
# Result is 0, which is wrong
print(result, ggx)

zou3519 avatar Aug 17 '22 20:08 zou3519

A smaller repro is grad(g)(x), which returns a tensor that does not require grad (but it should, since x requires grad).
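
For reference, a minimal sketch of that smaller repro (reusing f and g from the script above, with use_reentrant=False) looks roughly like this:

import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad

def f(x):
  return x.sin().sin()

def g(x):
  return checkpoint(f, x, use_reentrant=False)

x = torch.tensor(5., requires_grad=True)
gx = grad(g)(x)
# gx should require grad (x does), so a second level of differentiation could
# backprop through it; instead it comes back with requires_grad=False, which
# is why grad(grad(g))(x) silently returns 0.
print(gx.requires_grad)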

zou3519 avatar Sep 08 '22 13:09 zou3519

There are two problems here.

  1. Removing this detach call makes the above example work. I suspect we need a shallow_detach operation, or something that only detaches the top layer of autograd; this detach call detaches ALL layers of autograd. This is similar to something @ezyang and I discussed a bit last week: shallow_copy_and_detach emitting a detach() call when it really should mean "detach a single layer".
  2. Checkpointing isn't composable with functorch transforms. Concretely, the problem is that anything that sets TLS during the forward pass needs to be special-cased here so that the forward pass during checkpointing is re-run in the same way it was run originally (a toy illustration of the re-run is sketched below). I'm thinking modes may fall into this category and also might be broken under checkpointing, cc @samdow.

(1) is easy to fix; (2) is difficult and may fall into the category of "higher-order operators". Because this is difficult to fix generally, my plan is to figure out how to disallow checkpointing with functorch for now.
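
To make the TLS concern in (2) concrete, here's a toy illustration (a plain Python dict stands in for TLS or an active mode; this is not how checkpoint actually tracks state): the checkpointed function runs once during the forward and again during the backward, and any ambient state that changed in between is visible to the recompute.

import torch
from torch.utils.checkpoint import checkpoint

STATE = {"mode": "forward"}  # stand-in for TLS / an active mode

def f(x):
  print("running f with mode =", STATE["mode"])
  return x.sin().sin()

x = torch.tensor(5., requires_grad=True)
y = checkpoint(f, x, use_reentrant=False)  # prints mode = forward

STATE["mode"] = "backward"  # ambient state changes before backward
y.backward()                # recompute runs f again, prints mode = backward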

zou3519 avatar Sep 12 '22 14:09 zou3519

We have a general way of capturing all TLS in C++ to move it to other threads; we just have to bind it in Python.

ezyang avatar Sep 12 '22 18:09 ezyang

So, it turns out that (1) (whatever is going on with the detach() call) is subtle and potentially difficult to resolve. TL;DR: checkpointing does some things that are "not composite compliant".

Analysis is in https://docs.google.com/document/d/1xVRFtItMkIqs9eqMj2jqv-SQNoHZh8m-V71oSPsdKmc/edit

zou3519 avatar Sep 12 '22 20:09 zou3519

Issue mitigated by https://github.com/pytorch/pytorch/pull/85829, but it still needs a proper fix.

zou3519 avatar Sep 29 '22 17:09 zou3519