Migrate PyTorch/XLA's gradient checkpointing to the upstream one
🚀 Feature
Today PyTorch/XLA asks users to use its own version of gradient checkpointing in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L145-L146. We should extend the upstream API instead of asking users to use our version.
Motivation
The upstream gradient checkpointing doesn't work on XLA because XLA's CSE (common subexpression elimination) pass will undo the checkpointing. More details in https://github.com/pytorch/xla/issues/5766#issuecomment-1792913756. As a result, I copied the upstream checkpointing code and added an optimization_barrier_ on the inputs of the backward recompute (a rough sketch of the idea is shown after this list). This is bad because:
- Our implementation gets outdated very quickly.
- It is difficult for users to discover our version of gradient checkpointing.
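For context, the fork's core trick amounts to something like the sketch below; recompute_with_barrier is a hypothetical name for illustration, not the actual code at the link above.

```python
import torch
import torch_xla.core.xla_model as xm


def recompute_with_barrier(run_function, *inputs):
    # Put an optimization barrier on the saved inputs before re-running the
    # forward, so XLA's CSE pass cannot fold the recomputation back into the
    # original forward graph.
    tensor_inputs = [t for t in inputs if isinstance(t, torch.Tensor)]
    xm.optimization_barrier_(tensor_inputs)  # operates on the tensors in place
    with torch.enable_grad():
        return run_function(*inputs)
```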
Pitch
I chatted with @soulitzer; there is a way to pass a context manager to extend the gradient checkpointing behavior. @soulitzer even went ahead and wrote a draft:
```python
import contextlib

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._pytree import tree_map_only
from torch.utils.checkpoint import checkpoint
from torch.utils.weak import WeakTensorKeyDictionary


class MarkInputsToRegion(TorchFunctionMode):
    def __init__(self, mark_fn):
        # tensor -> bool
        self.is_marked = WeakTensorKeyDictionary()
        self.mark_fn = mark_fn

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def mark(x):
            self.mark_fn(x)
            self.is_marked[x] = True

        tree_map_only(torch.Tensor, mark, (args, kwargs))
        out = func(*args, **kwargs)
        tree_map_only(torch.Tensor, mark, out)
        return out


def context_fn():
    def mark_fn(x):
        print("input to region: ", x)

    return contextlib.nullcontext(), MarkInputsToRegion(mark_fn)


# Test a tensor that is closed over
y = torch.tensor([2.], requires_grad=True)
x = torch.tensor([1.], requires_grad=True)


def func(x):
    # the output of this mul or this clone should not be wrapped
    out = x * y
    return out.clone()


out = checkpoint(func, x, context_fn=context_fn, use_reentrant=False)
out.sum().backward()
```
What we should verify is that optimization_barrier_ only gets applied to the inputs of the backward recompute, not to the whole backward pass. I think we should take the above code, play with it, and verify whether we can use this approach to extend gradient checkpointing.
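For illustration, an XLA-flavored context_fn could look roughly like the sketch below, reusing MarkInputsToRegion from the draft above; xla_context_fn is a hypothetical name, and whether the barrier ends up only on the recompute inputs is exactly what needs to be verified.

```python
import contextlib

import torch_xla.core.xla_model as xm

# MarkInputsToRegion is the TorchFunctionMode from the draft above.


def xla_context_fn():
    def mark_fn(x):
        # Attach a barrier to every tensor entering the recomputed region so
        # XLA's CSE pass cannot merge the recomputation with the original graph.
        xm.optimization_barrier_([x])

    # The first context manager wraps the original forward, the second wraps
    # the backward recompute; only the recompute gets the barrier.
    return contextlib.nullcontext(), MarkInputsToRegion(mark_fn)


# out = checkpoint(func, x, context_fn=xla_context_fn, use_reentrant=False)
```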
On top of the optimization_barrier_, we also do some PyTorch/XLA RNG seed state management in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L151-L155. We should think about how to handle this part in the extension as well.
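Roughly, that RNG handling boils down to capturing xm.get_rng_state() during the forward and replaying it around the recompute; restore_xla_rng below is a hypothetical helper for illustration, not the code at the link.

```python
import contextlib

import torch_xla.core.xla_model as xm


@contextlib.contextmanager
def restore_xla_rng(saved_state):
    # Rewind the XLA RNG to the seed captured during the forward pass so the
    # recompute replays dropout etc. identically, then put back whatever state
    # the backward pass was using.
    current = xm.get_rng_state()
    xm.set_rng_state(saved_state)
    try:
        yield
    finally:
        xm.set_rng_state(current)


# forward:            saved_state = xm.get_rng_state()
# backward recompute: with restore_xla_rng(saved_state): run_function(*inputs)
```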
cc @alanwaketan @jonb377 @albanD
This sounds exciting!
Awesome! One thing we may need to handle is autocast state with gradient checkpointing - the upstream restores state using device modules (e.g. torch.cuda or torch.cpu), and it fetches the device module using getattr(torch, device), which won't work out-of-the-box for us.
We can probably just extend the _get_device_module logic in the upstream to support torch_xla.
@jonb377 we should update the checkpoint code to use the new device-generic API for AMP here: https://github.com/pytorch/pytorch/pull/124479
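Assuming a recent PyTorch where the per-device-type autocast functions from that PR are available (and that "xla" is accepted as a device_type string), the device-generic calls would look roughly like this; the exact usage inside the checkpoint code may differ.

```python
import torch

device_type = "xla"  # instead of a getattr(torch, device).amp-style lookup

# Query the autocast state for a device type by string...
enabled = torch.is_autocast_enabled(device_type)
dtype = torch.get_autocast_dtype(device_type)

# ...and restore it around the recompute with the generic context manager.
with torch.amp.autocast(device_type=device_type, dtype=dtype, enabled=enabled):
    pass  # the recomputed forward would run here
```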
cc @tengyifei
If possible, please check and fix; I'm still facing #5766.