openfold
Explicitly set reentrant to `False` for torch checkpointing
Description
The purpose of this PR is to make OpenFold compatible with torch 2.6 (technically >=2.4), in particular so that `torch.compile` can be used on modules that perform activation checkpointing.
As discussed here, torch 2.4 and newer require explicitly passing `use_reentrant` to the checkpointing function. Prior to torch 2.4 (e.g. 2.2), `use_reentrant` defaulted to `True`; however, we have found that non-reentrant checkpointing works better with DDP and `torch.compile`. Documentation on this is surprisingly hard to find, but it seems to match the anecdotal experience of others (ex). As a result, this change forces `use_reentrant=False`, enabling the use of `torch.compile` with our structure models.
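For reviewers who want a quick sanity check, here is a minimal sketch (not code from this PR) showing `torch.utils.checkpoint.checkpoint` called with `use_reentrant=False` passed explicitly. It assumes a toy `nn.Sequential` block standing in for an OpenFold layer and checks that non-reentrant checkpointing yields the same gradients as running the block without checkpointing.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Toy block standing in for an OpenFold layer (hypothetical, for illustration only).
block = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 8),
)
x_data = torch.randn(4, 8)


def grads(use_checkpoint: bool):
    """Run one forward/backward pass and return (input grad, first-layer weight grad)."""
    block.zero_grad()
    x = x_data.clone().requires_grad_(True)
    if use_checkpoint:
        # Pass use_reentrant explicitly instead of relying on the default,
        # which newer torch releases warn about.
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return x.grad.clone(), block[0].weight.grad.clone()


gx_plain, gw_plain = grads(use_checkpoint=False)
gx_ckpt, gw_ckpt = grads(use_checkpoint=True)
```

Because the toy block has no dropout or other randomness, the recomputed forward pass is identical, so both gradient pairs should match.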
For Discussion
If maintainers are concerned about backwards compatibility, I could instead adapt this PR so that `checkpoint_blocks` takes `use_reentrant` as a user-specified kwarg that defaults to `True` and is passed through to `get_checkpoint_fn`. Let me know!
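A rough sketch of what that alternative could look like follows. The signatures below are hypothetical simplifications and not OpenFold's actual `checkpoint_blocks`/`get_checkpoint_fn`: this sketch checkpoints each block individually on a single tensor argument and ignores any DeepSpeed or multi-block chunking logic. The point is only to show the flag being bound once and threaded through, with `True` as the default to preserve the old behavior.

```python
import functools
from typing import Callable, Sequence

import torch
from torch.utils.checkpoint import checkpoint


def get_checkpoint_fn(use_reentrant: bool = True) -> Callable:
    # Hypothetical signature: bind use_reentrant once so every call site passes it explicitly.
    return functools.partial(checkpoint, use_reentrant=use_reentrant)


def checkpoint_blocks(
    blocks: Sequence[Callable],
    x: torch.Tensor,
    use_reentrant: bool = True,  # default True keeps pre-2.4 behavior
) -> torch.Tensor:
    # Simplified: one checkpoint per block, single tensor input (assumption).
    ckpt = get_checkpoint_fn(use_reentrant=use_reentrant)
    for block in blocks:
        x = ckpt(block, x)
    return x
```

Callers that want `torch.compile` support would then pass `use_reentrant=False`, while existing callers keep the old default.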
@jnwei apologies for the ping out of nowhere, but are you the right person to be requesting review from? I see you're the most recent person to merge a PR.