
Migrate PyTorch/XLA's gradient checkpointing to upstream one

Open JackCaoG opened this issue 1 year ago • 5 comments

🚀 Feature

Today PyTorch/XLA asks users to use its own version of gradient checkpointing in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L145-L146. We should extend the upstream API instead of asking users to use our version.
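For context, a minimal sketch of the user experience today versus the end state we want (block and hidden_states are hypothetical stand-ins for a module and its inputs):

# Today: users have to find and import the XLA-specific wrapper.
from torch_xla.utils.checkpoint import checkpoint
out = checkpoint(block, hidden_states)

# Goal: the stock upstream API should just work on XLA devices.
from torch.utils.checkpoint import checkpoint
out = checkpoint(block, hidden_states, use_reentrant=False)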

Motivation

Upstream gradient checkpointing doesn't work because XLA's CSE (common subexpression elimination) pass will undo the gradient checkpointing. More details in https://github.com/pytorch/xla/issues/5766#issuecomment-1792913756. As a result I copied the upstream checkpointing and added an optimization_barrier_ on the inputs of the backward recompute (see the sketch after the list below). This is bad because

  1. Our implementation gets outdated very quickly
  2. It is difficult for users to discover our version of the gradient checkpointing
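For reference, the core of the workaround looks roughly like this (a simplified sketch, not the actual code in checkpoint.py; _recompute is a hypothetical helper):

import torch
import torch_xla.core.xla_model as xm

def _recompute(run_function, *saved_inputs):
    # Pin the saved inputs behind an optimization barrier before re-running
    # the forward, so that XLA's CSE pass cannot merge the recomputation with
    # the original forward graph and silently undo the checkpointing.
    xm.optimization_barrier_(
        [t for t in saved_inputs if isinstance(t, torch.Tensor)])
    return run_function(*saved_inputs)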

Pitch

I chatted with @soulitzer; there is a way to pass a context manager to extend the gradient checkpointing behavior. @soulitzer even went ahead and wrote a draft:

import contextlib

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._pytree import tree_map_only
from torch.utils.checkpoint import checkpoint
from torch.utils.weak import WeakTensorKeyDictionary

class MarkInputsToRegion(TorchFunctionMode):
    def __init__(self, mark_fn):
        # tensor -> bool: tensors already seen inside the region
        self.is_marked = WeakTensorKeyDictionary()
        self.mark_fn = mark_fn

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def mark(x):
            # Only tensors we have not seen yet (i.e. produced outside the
            # region) are treated as region inputs and passed to mark_fn.
            if self.is_marked.get(x):
                return
            self.mark_fn(x)
            self.is_marked[x] = True

        tree_map_only(torch.Tensor, mark, (args, kwargs))
        out = func(*args, **kwargs)

        def mark_seen(x):
            # Outputs are produced inside the region; remember them so they
            # are not later mistaken for region inputs.
            self.is_marked[x] = True

        tree_map_only(torch.Tensor, mark_seen, out)
        return out

def context_fn():
    def mark_fn(x):
        print("input to region: ", x)
    return contextlib.nullcontext(), MarkInputsToRegion(mark_fn)

# Test a tensor that is closed over
y = torch.tensor([2.], requires_grad=True)
x = torch.tensor([1.], requires_grad=True)

def func(x):
    # the output of this mul or this clone should not be wrapped
    out = x * y
    return out.clone()

out = checkpoint(func, x, context_fn=context_fn, use_reentrant=False)
out.sum().backward()

What we should verify is that optimization_barrier_ is only applied to the inputs of the backward recompute, not to all of the backward ops. I think we should take the above code, play with it, and verify whether we can use this approach to extend gradient checkpointing.
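For example, a first experiment could look something like this (a sketch only; xla_context_fn is hypothetical, and whether the barrier alone is enough to stop CSE still needs to be verified):

import contextlib

import torch_xla.core.xla_model as xm
from torch.utils.checkpoint import checkpoint

def xla_context_fn():
    def mark_fn(x):
        # Pin every tensor flowing into the recomputed region behind an
        # optimization barrier so CSE cannot fold the recompute away.
        xm.optimization_barrier_([x])
    # No special handling for the original forward; only the recompute
    # (the second context manager in the pair) gets the barrier treatment.
    return contextlib.nullcontext(), MarkInputsToRegion(mark_fn)

out = checkpoint(func, x, context_fn=xla_context_fn, use_reentrant=False)
out.sum().backward()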

On top of the optimization_barrier_, we also do some PyTorch/XLA RNG seed state management in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L151-L155. We should think about how to handle this part in the extension as well.
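One option (a rough sketch, assuming xm.get_rng_state()/xm.set_rng_state() are the right hooks) is to fold the seed handling into the same context_fn pair: the forward context captures the XLA RNG state and the recompute context replays it:

import torch_xla.core.xla_model as xm

def xla_rng_context_fn():
    saved = {}

    class SaveRngState:
        def __enter__(self):
            # Remember the RNG state the original forward ran with.
            saved["state"] = xm.get_rng_state()

        def __exit__(self, *exc):
            return None

    class RestoreRngState:
        def __enter__(self):
            # Replay the saved state so dropout masks etc. match the original
            # forward, then put the current state back afterwards.
            self.current = xm.get_rng_state()
            xm.set_rng_state(saved["state"])

        def __exit__(self, *exc):
            xm.set_rng_state(self.current)
            return None

    return SaveRngState(), RestoreRngState()

In practice this would have to be composed with the barrier context above, e.g. by returning combined context managers from a single context_fn.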

cc @alanwaketan @jonb377 @albanD

JackCaoG avatar May 03 '24 19:05 JackCaoG

This sounds exciting!

alanwaketan avatar May 03 '24 19:05 alanwaketan

Awesome! One thing we may need to handle is autocast state with gradient checkpointing - the upstream implementation restores state using device modules (e.g. torch.cuda or torch.cpu), and it fetches the device module using getattr(torch, device), which won't work out-of-the-box for us.

We can probably just extend the _get_device_module logic in the upstream to support torch_xla.
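For illustration, the failing lookup and one possible extension (a sketch; the real fix may end up looking different once the device-generic APIs are in place):

import torch

def _get_device_module(device="cuda"):
    # Roughly what upstream does today: getattr(torch, "xla") raises
    # AttributeError because there is no torch.xla submodule.
    return getattr(torch, device)

def _get_device_module_with_xla(device="cuda"):
    # Hypothetical extension: route "xla" to torch_xla (likely via a small
    # shim exposing the rng/autocast helpers the checkpoint code expects).
    if device == "xla":
        import torch_xla.core.xla_model as xm
        return xm
    return getattr(torch, device)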

jonb377 avatar May 03 '24 20:05 jonb377

@jonb377 we should update the checkpoint code to use the new device-generic API for amp here https://github.com/pytorch/pytorch/pull/124479

albanD avatar May 03 '24 20:05 albanD

cc @tengyifei

miladm avatar Jul 26 '24 17:07 miladm

If possible, please check and fix; I'm still facing #5766.

steveepreston avatar Sep 21 '25 19:09 steveepreston