pytorch-image-models
[FEATURE] Gradient checkpointing in `forward_intermediates()`
Is your feature request related to a problem? Please describe.
I rely on the forward_intermediates() API for object detection models, and I'm experimenting with ViT-g and would like to try gradient checkpointing.
Describe the solution you'd like
In VisionTransformer.forward_features() we have:
```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint_seq(self.blocks, x)
```
I'm thinking something like this could work in VisionTransformer.forward_intermediates():
```python
for i, blk in enumerate(blocks):
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_module(blk, x)
    else:
        x = blk(x)
```
I called this checkpoint_module(), but based on the code I think we could just use checkpoint_seq() directly? Either way, is this as simple as I think it would be, or am I missing something? I haven't used gradient checkpointing much, so I'm not entirely sure.
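To make the distinction concrete, here's a toy, torch-free sketch (all names hypothetical, `checkpoint_one` standing in for `torch.utils.checkpoint.checkpoint`) of why forward_intermediates() wants per-block checkpointing: the loop has to observe each block's output, so the whole sequence can't be wrapped in one checkpointed call the way forward_features() can.

```python
# Toy sketch (pure Python, hypothetical helpers) of per-block vs
# whole-sequence checkpointing. Not timm code.

def checkpoint_one(fn, x):
    # Stand-in for torch.utils.checkpoint.checkpoint(fn, x): runs the block
    # without storing its internal activations (recomputed on backward).
    return fn(x)

def forward_features_style(blocks, x):
    # forward_features() only needs the final output, so it can checkpoint
    # the block sequence in larger segments (roughly what checkpoint_seq does).
    for blk in blocks:
        x = checkpoint_one(blk, x)
    return x

def forward_intermediates_style(blocks, x, take_indices):
    # Here selected per-block outputs must be captured, so checkpointing
    # is applied one block at a time inside the loop.
    intermediates = []
    for i, blk in enumerate(blocks):
        x = checkpoint_one(blk, x)
        if i in take_indices:
            intermediates.append(x)
    return x, intermediates
```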
I'm happy to submit a PR for a few models if it's as simple as calling checkpoint_seq() in forward_intermediates() as I've outlined above. I'm not sure how many models use this API and/or self.grad_checkpointing, and whether you want this to be supported in all of them.
I just noticed that for ConvNeXt the gradient checkpointing is done within a ConvNeXt stage, which means it would work as-is for forward_intermediates(). So maybe this feature request is specific to VisionTransformer (and other models whose gradient checkpointing won't work within forward_intermediates()).
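A minimal sketch (hypothetical names, not ConvNeXt's actual classes) of that stage-based pattern: because checkpointing lives inside each stage, an outer loop that collects one feature per stage keeps working unchanged.

```python
# Hypothetical sketch of stage-internal checkpointing. The outer
# forward_intermediates loop never needs to know about it.

def run_checkpointed(fn, x):
    # Stand-in for torch.utils.checkpoint.checkpoint(fn, x).
    return fn(x)

class Stage:
    def __init__(self, blocks, grad_checkpointing=False):
        self.blocks = blocks
        self.grad_checkpointing = grad_checkpointing

    def __call__(self, x):
        # Checkpointing is applied here, inside the stage.
        for blk in self.blocks:
            x = run_checkpointed(blk, x) if self.grad_checkpointing else blk(x)
        return x

def forward_intermediates(stages, x):
    # One feature per stage; checkpointing is transparent at this level.
    feats = []
    for stage in stages:
        x = stage(x)
        feats.append(x)
    return x, feats
```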
Also, shouldn't this be called activation checkpointing rather than gradient checkpointing? Just want to make sure I'm not misunderstanding the implementation / goal here. I'm guessing the name comes from the HuggingFace trainer flag, but it's a bit of a misnomer?
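For intuition on the naming: a toy, torch-free sketch (illustrative functions only, manual chain rule on `y = w * x` layers) of what checkpointing actually trades off. What gets discarded and recomputed are the activations; the gradients themselves are computed normally in the backward pass.

```python
# Toy illustration of activation checkpointing: fewer stored
# activations in forward, one extra recomputation in backward.

def forward_no_checkpoint(x, weights, saved):
    # Standard forward: store every intermediate activation for backward.
    for w in weights:
        saved.append(x)      # activation kept in memory
        x = w * x
    return x

def forward_with_checkpoint(x, weights, saved):
    # Checkpointed forward: store only the block's input; activations
    # inside the block are discarded and recomputed later.
    saved.append(x)          # single checkpoint for the whole block
    for w in weights:
        x = w * x
    return x

def backward_with_checkpoint(grad_out, weights, saved):
    # Recompute the discarded activations from the checkpoint (the extra
    # forward pass), then apply the chain rule: d(w*x)/dx = w, d(w*x)/dw = x.
    x = saved.pop()
    acts = []
    for w in weights:
        acts.append(x)
        x = w * x
    grad_w = []
    for w, a in zip(reversed(weights), reversed(acts)):
        grad_w.append(grad_out * a)   # dL/dw for this layer
        grad_out = grad_out * w       # dL/dx propagated backward
    return grad_out, list(reversed(grad_w))
```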
@collinmccarthy you are correct on all counts. I didn't explicitly support this when I added forward_intermediates() as I was focused on getting it working / integrated, and then didn't revisit.
Stage-based models that push the checkpointing logic into the stages should still work.
Activation checkpointing makes more sense as the name / description of what's going on, but historically it was often called gradient checkpointing, so the name persisted. Not going to change that now.
If you've tried the above additions and it works a PR would be welcome for any models that you happen to be working with.
Should use my checkpoint wrapper around the torch one (it changes the reentrant arg):
```python
from ._manipulate import checkpoint

...

def forward_intermediates(self, x, ...):
    ...
    for blk in self.blocks:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(blk, x)
        else:
            x = blk(x)
```
Thanks, this all sounds great. I'll submit a PR soon for just VisionTransformer for now, and if I run across other models I need in the future I'll submit PRs for those and reference this issue.