
Raise an error when PyTorch's activation checkpointing is used with Thunder-jitted model

Open IvanYashchuk opened this issue 1 year ago • 2 comments

🐛 Bug

Activation checkpointing is a technique to reduce memory usage by discarding the activations of certain layers during the forward pass and recomputing them during the backward pass. Thunder doesn't know how to deal with PyTorch's annotations and wrappers implementing this feature, which leads to silent or loud failures, like https://github.com/Lightning-AI/lightning-thunder/issues/582.
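
For context, this is roughly what the feature looks like in plain PyTorch (a minimal, illustrative example; the model code is made up):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class Model(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            # Activations inside `block` are dropped after the forward pass
            # and recomputed when the backward pass reaches this block.
            x = checkpoint(block, x, use_reentrant=False)
        return x


model = Model()
out = model(torch.randn(8, 128, requires_grad=True))
out.sum().backward()
```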

When an nn.Module passed to thunder.jit contains an instance of torch.distributed.algorithms._checkpoint.checkpoint_wrapper.CheckpointWrapper among its submodules, Thunder should raise an error.

Is there any other indicator that the user used torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing?
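
A straightforward check, sketched below, is to scan the submodules of the jitted module for CheckpointWrapper instances before tracing (the helper name `check_no_checkpoint_wrapper` is just illustrative, not an existing Thunder API). If I remember correctly, apply_activation_checkpointing also stores the original module under a `_checkpoint_wrapped_module` attribute, so that string appearing in `named_parameters()` keys could serve as a secondary indicator.

```python
# Illustrative sketch only: reject modules that contain PyTorch's
# activation-checkpointing wrapper before thunder.jit tries to trace them.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper


def check_no_checkpoint_wrapper(module: nn.Module) -> None:
    for name, submodule in module.named_modules():
        if isinstance(submodule, CheckpointWrapper):
            raise RuntimeError(
                f"Submodule '{name}' is a CheckpointWrapper (PyTorch activation "
                "checkpointing), which is not supported by thunder.jit yet."
            )
```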

cc @apaz-cli

IvanYashchuk avatar Jul 15 '24 14:07 IvanYashchuk

I wonder how hard it would be to actually support basic CheckpointWrapper by having a "container symbol" and defining a grad transform that includes the forward in the backward...
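
In plain PyTorch terms, the idea would be something like this (just a sketch of the recompute-in-backward concept, not Thunder machinery):

```python
import torch


class RecomputeInBackward(torch.autograd.Function):
    """Runs `fn` without saving intermediates; replays the forward in backward."""

    @staticmethod
    def forward(ctx, fn, *args):
        ctx.fn = fn
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return fn(*args)

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        with torch.enable_grad():
            # Detach the saved inputs so a fresh graph is built here.
            inputs = [t.detach().requires_grad_(t.requires_grad) for t in saved]
            out = ctx.fn(*inputs)  # the forward runs again, inside backward
        grads = torch.autograd.grad(out, [t for t in inputs if t.requires_grad], grad_out)
        grads_iter = iter(grads)
        return (None, *[next(grads_iter) if t.requires_grad else None for t in inputs])


def checkpointed(fn, *args):
    return RecomputeInBackward.apply(fn, *args)


x = torch.randn(4, 4, requires_grad=True)
y = checkpointed(lambda t: torch.tanh(t @ t.T), x)
y.sum().backward()
```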

t-vi avatar Jul 15 '24 14:07 t-vi

  • plan to do automatic checkpointing based on memory available (eventually)
  • but this is also a hugely important thing for memory use, which is a big deal right now
  • interim solution: make use of user's annotation? (This might actually be harder than just doing the full solution.)
  • can we just push a scope on the stack and use that for handling what gets checkpointed? yes, but then we'll need to modify backward
  • we would accept compromises, but let's wait for a bigger design

tfogal avatar Jul 15 '24 15:07 tfogal