feat: add activation checkpointing to unet
Description
Introduces an optional `use_checkpointing` flag in the `UNet` implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory.
- Implemented via a lightweight `_ActivationCheckpointWrapper` around sub-blocks.
- Checkpointing is only applied during training to avoid overhead at inference (see the usage sketch below).
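For illustration, a minimal sketch of how the proposed flag might be used. The flag name comes from this PR's description; the remaining constructor arguments follow `UNet`'s existing signature:

```python
import torch
from monai.networks.nets import UNet

# Hypothetical usage of the proposed flag; `use_checkpointing` is taken from
# this PR's description and may differ from the merged API.
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64),
    strides=(2, 2),
    use_checkpointing=True,  # recompute activations in backward to save memory
)

x = torch.randn(1, 1, 32, 32, 32)
y = net(x)  # checkpointing only takes effect while net.training is True
```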
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
Walkthrough
Adds `monai/networks/blocks/activation_checkpointing.py` implementing `ActivationCheckpointWrapper`, which applies `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` to a wrapped `nn.Module`. Adds `CheckpointUNet(UNet)` in `monai/networks/nets/unet.py`, which overrides `_get_connection_block` to wrap the connection subblock, `down_path`, and `up_path` with `ActivationCheckpointWrapper`, and updates `__all__` to export `CheckpointUNet`. Adds tests in `tests/networks/nets/test_checkpointunet.py` covering shape propagation (2D/3D), eval-mode equivalence to `UNet`, and training-time gradient behavior for checkpointed blocks.
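A minimal sketch of the wrapper described above, assuming the wrapped module's forward takes a single tensor; the actual implementation in `monai/networks/blocks/activation_checkpointing.py` may differ in details:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ActivationCheckpointWrapper(nn.Module):
    """Recomputes the wrapped module's activations during the backward pass
    instead of storing them, trading compute for memory."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # use_reentrant=False selects PyTorch's non-reentrant checkpointing,
        # matching the call described in the walkthrough.
        return checkpoint(self.module, x, use_reentrant=False)
```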
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
Areas to pay attention to:
- `monai/networks/blocks/activation_checkpointing.py`: correct use of `torch.utils.checkpoint.checkpoint`, forward signature, device/grad interactions, and `use_reentrant=False`.
- `monai/networks/nets/unet.py`: correctness of the `_get_connection_block` wrapping of `subblock`, `down_path`, and `up_path`; interaction with module training/eval state and potential double-wrapping or attribute access changes.
- `tests/networks/nets/test_checkpointunet.py`: test robustness for deterministic equivalence in eval mode, gradient checking in training mode, and parameterization coverage for 2D/3D and channel/stride variants (a minimal sketch follows below).
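As referenced in the last bullet, a minimal sketch of the eval-mode equivalence check. It assumes `CheckpointUNet` is importable as described in the walkthrough and that constructing both models under the same seed yields identical parameters (the wrapper itself adds none); the actual test file may differ:

```python
import torch
from monai.networks.nets import UNet
from monai.networks.nets.unet import CheckpointUNet  # exported per the walkthrough


def build(cls):
    torch.manual_seed(0)  # same seed -> same parameter initialization order
    return cls(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16), strides=(2,))


ref, ckpt = build(UNet), build(CheckpointUNet)
ref.eval()
ckpt.eval()

x = torch.randn(1, 1, 32, 32)
with torch.no_grad():
    # In eval mode the checkpointed model should match plain UNet exactly.
    assert torch.allclose(ref(x), ckpt(x), atol=1e-6)
```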
Pre-merge checks
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | Title directly summarizes the main change: adding activation checkpointing capability to UNet. |
| Description check | ✅ Passed | Description covers implementation approach, key features, and addresses all required checklist items; follows template structure. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%. |
Hi @ferreirafabio80, thanks for the contribution, but I would suggest this isn't necessarily the way to go when adapting this class. Perhaps instead you can create a subclass of `UNet` and override the method:
```python
class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)
```
This would suffice for your own use if you just wanted such a definition. I think the `_ActivationCheckpointWrapper` class may be a good thing to add to the `blocks` submodule instead, so it should have a public name like `CheckpointWrapper`.
I see also that `checkpoint` is used elsewhere in MONAI already, like here, without the checks for training and gradient that you have in your class, so I wonder if these are needed at all?
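For context, the kind of guard being discussed might look like the following hedged sketch (the class name here is hypothetical, not the PR's actual code): it bypasses checkpointing whenever recomputation cannot pay off, i.e. in eval mode or when no gradient flows through the input.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class _GuardedCheckpointWrapper(nn.Module):
    """Hypothetical variant with the training/gradient guards under discussion."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip checkpointing outside training, or when the input carries no
        # gradient, since recomputation then saves nothing.
        if not self.training or not x.requires_grad:
            return self.module(x)
        return checkpoint(self.module, x, use_reentrant=False)
```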
Hi @ericspod, thank you for your comments.
Yes, that also works. I've defined a subclass and overridden the method as you suggested.
Regarding the `_ActivationCheckpointWrapper` class, should I create a new script in the `blocks` submodule or add it to an existing one?
I was probably being extremely careful with the checks in `_ActivationCheckpointWrapper`, but I agree we can drop them.
> Regarding the `_ActivationCheckpointWrapper` class, should I create a new script in the `blocks` submodule or add it to an existing one?
Sorry for the delay. I think we should put this into a new file in the `monai/networks/blocks` directory since it doesn't really go anywhere else. You can then give it a good name and docstring comments, plus whatever else CodeRabbit has said that's reasonable. Thanks!
@ericspod I've moved the wrapper to a different script, added docstrings and a test (which is literally a copy of the UNet one). Let me know if this is sensible.
This looks much better, thanks. Please do work on the testing issues; the DCO issue can be left until last.
Thanks. I have fixed the testing issues. What is the easiest way to fix the DCO issue?
Hi @ferreirafabio80, you can refer to the guide here: https://github.com/Project-MONAI/MONAI/pull/8554/checks?check_run_id=54816594479
Thank you @KumoLiu. Just fixed it.
Hi @ferreirafabio80, it's looking pretty good here. If you can resolve the last conversation items with CodeRabbit (e.g. updating docstrings) and others, I think we can merge this soon.
@ericspod sorry, missed that one. All addressed now.
/build