functorch
Check that checkpointing works with functorch grad transforms
See title. From discussion with @albanD
FYI @rohan-varma
The answer is no. There are a number of things that need to be resolved.
- torch.autograd.grad doesn't compose with torch.utils.checkpoint
- autograd.Function doesn't work with functorch.grad. I guess torch.utils.checkpoint builds an autograd.Function somewhere?
- After we get past the above two, I think there will still be problems. It depends on when the saved variable hook gets installed. In functorch grad(grad(g)), do we install a saved variable hook in both graphs that are being created, or just one of them? If the former, then we are good; if the latter (which I suspect is what happens), then we've got a problem.
Here's a script that demonstrates (1) and (2):
```python
import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad, make_fx

def f(x):
    y = x.sin()
    z = y.sin()
    return z

def g(x):
    return checkpoint(f, x, use_reentrant=True)

x = torch.tensor(5.)

def grad_g(x):
    x.requires_grad_()
    y = g(x)
    gx, = torch.autograd.grad(y, x)
    x.requires_grad_(False)
    return gx

# Problem 1
gm = make_fx(grad_g)(x)
print(gm)

# Problem 2
grad(g)(x)
```
Oh yes, the default one will never work.
You should pass use_reentrant=False to get the version that doesn't use a custom Function.
Ahh, thanks for the clarification. It looks like functorch is silently incorrect on nested grad and checkpoint with use_reentrant=False:
```python
import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad, make_fx

def f(x):
    y = x.sin()
    z = y.sin()
    return z

def g(x):
    return checkpoint(f, x, use_reentrant=False)

x = torch.tensor(5.)
x.requires_grad_()
y = g(x)
gx, = torch.autograd.grad(y, x, create_graph=True)
ggx, = torch.autograd.grad(gx, x)

result = grad(grad(g))(x)
# Result is 0, which is wrong
print(result, ggx)
```
A smaller repro is grad(g)(x), which returns a tensor that does not require grad (but it should, since x requires grad).
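For reference, a minimal sketch of that smaller repro, reusing the same f and g as above; per the behavior described here, the output should require grad (since x does) but does not:

```python
import torch
from torch.utils.checkpoint import checkpoint
from functorch import grad

def f(x):
    y = x.sin()
    z = y.sin()
    return z

def g(x):
    return checkpoint(f, x, use_reentrant=False)

x = torch.tensor(5.)
x.requires_grad_()

out = grad(g)(x)
# Expected: True (the gradient should itself require grad because x does),
# but the observed behavior described above is False.
print(out.requires_grad)
```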
There are two problems here.
- Removing this detach call makes the above example work. I suspect we need a shallow_detach operation, or something that only detaches the top layer of autograd; this detach call is detaching ALL layers of autograd. This is similar to something @ezyang and I discussed a bit last week, where shallow_copy_and_detach emits a detach() call when it really should mean "detach a single layer".
- Checkpointing isn't composable with functorch transforms. Concretely, the problem is that anything that sets TLS during the forward pass needs to be special-cased here so that the forward pass during checkpointing is re-run in the same way it was run originally (see the sketch after this list). I'm thinking modes may fall into this category and might also be broken under checkpointing, cc @samdow.
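To make the TLS timing issue concrete, here is a toy sketch (not the real torch.utils.checkpoint implementation) built on torch.autograd.graph.saved_tensors_hooks, the mechanism that non-reentrant checkpointing uses: pack hooks run during the original forward, while unpack hooks (where non-reentrant checkpointing recomputes the forward) run during the backward pass, under whatever TLS happens to be active at that point.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

# Toy pack/unpack hooks that only log when saved activations are packed
# (during the forward) and unpacked (during the backward). Non-reentrant
# checkpointing's pack hook discards the activation and its unpack hook
# re-runs the forward to recompute it, so that recompute happens under the
# TLS active at backward time, not the TLS active during the original forward.
def pack(t):
    print("pack called during forward")
    return t

def unpack(t):
    print("unpack called during backward")
    return t

x = torch.tensor(5., requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = x.sin().sin()  # sin saves its input for backward, so pack fires here
y.backward()           # unpack fires here, while the backward pass is running
```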
(1) is easy to fix; (2) is difficult and may fall into the category of "higher order operators". Because this is difficult to fix in general, my plan is to figure out how to disallow checkpointing with functorch for now.
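As a rough illustration of what disallowing could look like, here is a hypothetical wrapper; _functorch_transforms_active is a made-up placeholder, since the real check would live inside the functorch/PyTorch internals:

```python
import torch
from torch.utils.checkpoint import checkpoint as _checkpoint

def _functorch_transforms_active():
    # Hypothetical placeholder: in practice this would query functorch's
    # interpreter/dynamic-layer stack to see whether a transform such as
    # grad or vmap is currently active.
    return False

def checkpoint(fn, *args, **kwargs):
    # Wrapper that refuses to checkpoint while a functorch transform is
    # active, instead of silently computing incorrect gradients.
    if _functorch_transforms_active():
        raise RuntimeError(
            "torch.utils.checkpoint is not supported under functorch "
            "transforms (grad/vmap) yet; see this issue."
        )
    return _checkpoint(fn, *args, **kwargs)
```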
We have a general way of capturing all TLS in C++ to move it to other threads; we just have to bind it in Python.
So, it turns out that (1) (whatever is going on with the detach() call) is subtle and potentially difficult to resolve. TL;DR: checkpointing does some things that "are not composite compliant".
Analysis is in https://docs.google.com/document/d/1xVRFtItMkIqs9eqMj2jqv-SQNoHZh8m-V71oSPsdKmc/edit
Issue mitigated by https://github.com/pytorch/pytorch/pull/85829, but it still needs a proper fix.