Raise an error when PyTorch's activation checkpointing is used with a Thunder-jitted model
🐛 Bug
Activation checkpointing is a technique to reduce memory usage by discarding the activations of certain layers and recomputing them during the backward pass. Thunder doesn't know how to deal with PyTorch's annotations and wrappers implementing this feature, which leads to silent or loud problems, such as https://github.com/Lightning-AI/lightning-thunder/issues/582.
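For context, this is what the feature looks like on the PyTorch side, a minimal sketch with an illustrative model (module names and shapes are made up, not from the issue):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)

class Model(nn.Module):
    def __init__(self, dim: int = 128, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            # Activations inside `block` are not kept; they are recomputed
            # during the backward pass.
            x = checkpoint(block, x, use_reentrant=False)
        return x

x = torch.randn(8, 128, requires_grad=True)
Model()(x).sum().backward()
```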
When an nn.Module passed to thunder.jit includes an instance of torch.distributed.algorithms._checkpoint.checkpoint_wrapper.CheckpointWrapper among its submodules, Thunder should raise an error.
Is there any other indicator that the user used torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing?
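A minimal sketch of the proposed check; `maybe_raise_on_checkpoint_wrapper` and the error message are hypothetical, not existing Thunder API:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

def maybe_raise_on_checkpoint_wrapper(module: nn.Module) -> None:
    # Hypothetical helper: walk the submodule tree and refuse modules that were
    # wrapped by checkpoint_wrapper / apply_activation_checkpointing.
    for name, submodule in module.named_modules():
        if isinstance(submodule, CheckpointWrapper):
            raise NotImplementedError(
                f"Submodule {name!r} is a CheckpointWrapper; PyTorch activation "
                "checkpointing is not supported with thunder.jit yet."
            )
```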
cc @apaz-cli
I wonder how hard it would be to actually support a basic CheckpointWrapper by having a "container symbol" and defining a grad transform that includes the forward in the backward...
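To illustrate the "include the forward in the backward" idea, here is a sketch in plain PyTorch (via `torch.autograd.Function`, not Thunder symbols); `RecomputeInBackward` is a hypothetical name:

```python
import torch

class RecomputeInBackward(torch.autograd.Function):
    # Sketch: save only the inputs, rerun the forward under enable_grad()
    # in backward, and backprop through the recomputed graph.
    # Usage: y = RecomputeInBackward.apply(block, x)
    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        with torch.no_grad():
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            y = ctx.fn(x_)  # the forward is replayed inside the backward
        (grad_in,) = torch.autograd.grad(y, x_, grad_out)
        return None, grad_in
```

torch.utils.checkpoint implements essentially this, plus RNG-state and autocast handling that a real container symbol would also have to respect.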
- plan to do automatic checkpointing based on available memory (eventually)
- but this is also a hugely important thing for memory use, which is a big deal right now
- interim solution: make use of the user's annotation (see the sketch after this list)? (This might actually be harder than just doing the full solution.)
- can we just push a scope on the stack and use that for handling what gets checkpointed? yes, but then we'll need to modify the backward
- we would accept compromises, but let's wait for a bigger design
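For reference, this is roughly how the user's annotation in question is applied (illustrative model; the `check_fn` choice is an assumption):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))

# Wraps the matching submodules in CheckpointWrapper in place; afterwards the
# Linear layers are reachable only through their wrappers.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, nn.Linear),
)
```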