
feat: add activation checkpointing to unet

Open ferreirafabio80 opened this issue 4 months ago • 4 comments

Description

Introduces an optional use_checkpointing flag in the UNet implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory.

  • Implemented via a lightweight _ActivationCheckpointWrapper wrapper around sub-blocks.
  • Checkpointing is only applied during training to avoid overhead at inference.

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [ ] New tests added to cover the changes.
  • [x] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [x] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [x] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

ferreirafabio80 avatar Sep 03 '25 15:09 ferreirafabio80

Walkthrough

  • Adds monai/networks/blocks/activation_checkpointing.py, implementing ActivationCheckpointWrapper, which applies torch.utils.checkpoint.checkpoint(..., use_reentrant=False) to a wrapped nn.Module.
  • Adds CheckpointUNet(UNet) in monai/networks/nets/unet.py, which overrides _get_connection_block to wrap the connection subblock, down_path, and up_path with ActivationCheckpointWrapper, and updates __all__ to export CheckpointUNet.
  • Adds tests in tests/networks/nets/test_checkpointunet.py covering shape propagation (2D/3D), eval-mode equivalence to UNet, and training-time gradient behavior for checkpointed blocks.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Areas to pay attention to:

  • monai/networks/blocks/activation_checkpointing.py: correct use of torch.utils.checkpoint.checkpoint, forward signature, device/grad interactions, and use_reentrant=False.
  • monai/networks/nets/unet.py: correctness of _get_connection_block wrapping of subblock, down_path, up_path; interaction with module training/eval state and potential double-wrapping or attribute access changes.
  • tests/networks/nets/test_checkpointunet.py: test robustness for deterministic equivalence in eval mode, gradient checking in training mode, and parameterization coverage for 2D/3D and channel/stride variants.
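The eval-mode equivalence and training-time gradient checks listed above can be sketched without MONAI; here a toy conv block stands in for a checkpointed UNet sub-block (all names are illustrative, not the PR's actual test code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointWrapper(nn.Module):
    """Toy stand-in for the PR's wrapper: checkpoint in training, plain forward in eval."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled() and x.requires_grad:
            return checkpoint(self.module, x, use_reentrant=False)
        return self.module(x)


torch.manual_seed(0)
block = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1))
wrapped = CheckpointWrapper(block)  # shares the very same parameters as `block`

# Eval-mode equivalence: wrapping must not change outputs.
wrapped.eval()
x = torch.randn(2, 1, 16, 16)
with torch.no_grad():
    assert torch.allclose(block(x), wrapped(x))

# Training mode: gradients still reach every parameter through the checkpoint.
wrapped.train()
xt = torch.randn(2, 1, 16, 16, requires_grad=True)
wrapped(xt).sum().backward()
assert all(p.grad is not None for p in wrapped.parameters())
```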

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)

  • Title check ✅ Passed: the title directly summarizes the main change, adding activation checkpointing capability to UNet.
  • Description check ✅ Passed: the description covers the implementation approach and key features, addresses all required checklist items, and follows the template structure.
  • Docstring Coverage ✅ Passed: docstring coverage is 80.00%, meeting the required 80.00% threshold.

coderabbitai[bot] avatar Sep 03 '25 15:09 coderabbitai[bot]

Hi @ferreirafabio80, thanks for the contribution, but I would suggest this isn't necessarily the way to go about adapting this class. Perhaps you could instead create a subclass of UNet and override the method:

class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)

This would suffice for your own use if you just wanted such a definition. I think the _ActivationCheckpointWrapper class may be a good thing to add to the blocks submodule instead, so it should have a public name like CheckpointWrapper.

I also see that checkpoint is already used elsewhere in MONAI, like here, without the checks for training and gradient state that you have in your class, so I wonder whether these are needed at all?
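For reference, PyTorch's non-reentrant checkpoint already degrades gracefully outside training: called under no_grad it simply runs the forward pass, since nothing needs to be saved for a backward that never happens. A small sketch (not MONAI code) illustrating why the extra guards may be optional:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(4, 8)

# Under no_grad (e.g. inference), non-reentrant checkpoint just executes the
# wrapped forward; there is no recomputation because no backward pass occurs.
with torch.no_grad():
    y = checkpoint(layer, x, use_reentrant=False)
assert y.shape == (4, 8)
```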

ericspod avatar Sep 26 '25 20:09 ericspod

Hi @ericspod, thank you for your comments.

Yes, that also works. I've defined a subclass and overridden the method as you suggested.

Regarding the _ActivationCheckpointWrapper class, should I create a new script in the blocks submodule or add it to an existing one?

I was probably being overly cautious with the checks in _ActivationCheckpointWrapper, but I agree we can drop them.

ferreirafabio80 avatar Oct 01 '25 15:10 ferreirafabio80

Regarding the _ActivationCheckpointWrapper class, should I create a new script in the blocks submodule or add it to an existing one?

Sorry for the delay, I think we should put this into a new file in the monai/networks/blocks directory since it doesn't really go anywhere else. You can then give it a good name and docstring comments, plus whatever else Coderabbit has said that's reasonable. Thanks!

ericspod avatar Oct 29 '25 22:10 ericspod

@ericspod I've moved the wrapper to a different script, added docstrings and a test (which is literally a copy of the unet one). Let me know if this is sensible.

ferreirafabio80 avatar Nov 07 '25 13:11 ferreirafabio80

@ericspod I've moved the wrapper to a different script, added docstrings and a test (which is literally a copy of the unet one). Let me know if this is sensible.

This looks much better, thanks. Please do work on the testing issues and the DCO issue can be left until last.

ericspod avatar Nov 07 '25 15:11 ericspod

Thanks. I have fixed the testing issues. What is the easiest way to fix the DCO issue?

ferreirafabio80 avatar Nov 07 '25 17:11 ferreirafabio80

Thanks. I have fixed the testing issues. What is the easiest way to fix the DCO issue?

Hi @ferreirafabio80, you can refer to the guide here: https://github.com/Project-MONAI/MONAI/pull/8554/checks?check_run_id=54816594479

KumoLiu avatar Nov 10 '25 06:11 KumoLiu

Thank you @KumoLiu. Just fixed it.

ferreirafabio80 avatar Nov 10 '25 09:11 ferreirafabio80

Hi @ferreirafabio80, it's looking pretty good here. If you can resolve the last conversation items with coderabbit (e.g. update docstrings) and the others, I think we can merge this soon.

ericspod avatar Nov 14 '25 10:11 ericspod

@ericspod sorry, missed that one. All addressed now.

ferreirafabio80 avatar Nov 14 '25 10:11 ferreirafabio80

/build

KumoLiu avatar Nov 14 '25 15:11 KumoLiu