Explicitly set reentrant to `False` for torch checkpointing

Open an1lam opened this issue 8 months ago • 1 comment

Description

The purpose of this PR is to set up OpenFold to be compatible with torch 2.6 (technically >=2.4), in particular for using torch.compile on modules that do activation checkpointing.

As discussed here, torch 2.4 and newer require explicitly passing use_reentrant to the checkpointing function. Prior to torch 2.4 (e.g. 2.2), use_reentrant defaulted to True. However, we have found that non-reentrant checkpointing works better with DDP and torch.compile. Documentation for this is surprisingly hard to find, but it seems to match the anecdotal experience of others (ex). As a result, this change forces use_reentrant=False, enabling the use of torch.compile with our structure models.
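For reviewers who want to see the call shape, here is a minimal sketch of passing use_reentrant explicitly to torch.utils.checkpoint.checkpoint. The toy `block` module and the tensor sizes are made up for illustration and are not taken from OpenFold.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy module standing in for a real model block (hypothetical, not OpenFold code).
block = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 8),
)

x = torch.randn(4, 8, requires_grad=True)

# Pass use_reentrant explicitly; False selects the non-reentrant implementation.
y = checkpoint(block, x, use_reentrant=False)

# Activations inside `block` are recomputed during the backward pass.
y.sum().backward()

grad_shape = tuple(x.grad.shape)
```

The same call with use_reentrant=True would select the older reentrant implementation, which is the behavior this PR moves away from.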

For Discussion

If maintainers are concerned about backwards compatibility and would strongly prefer it, we could adapt this PR so that checkpoint_blocks takes use_reentrant as a user-specified kwarg that defaults to True and is then passed to get_checkpoint_fn. Let me know!
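A rough sketch of what that alternative could look like is below. The two functions are simplified stand-ins that only borrow the names checkpoint_blocks and get_checkpoint_fn; their signatures and bodies here are assumptions for illustration and do not reproduce OpenFold's actual implementations.

```python
import functools
from typing import Callable, Sequence, Tuple

import torch
from torch.utils.checkpoint import checkpoint


def get_checkpoint_fn(use_reentrant: bool = True) -> Callable:
    # Simplified stand-in: bind use_reentrant once so no call site omits it.
    return functools.partial(checkpoint, use_reentrant=use_reentrant)


def checkpoint_blocks(
    blocks: Sequence[Callable],
    args: Tuple[torch.Tensor, ...],
    use_reentrant: bool = True,  # default preserves the pre-2.4 behavior
) -> Tuple[torch.Tensor, ...]:
    # Simplified stand-in: checkpoint each block individually, in order.
    ckpt = get_checkpoint_fn(use_reentrant)
    for block in blocks:
        out = ckpt(block, *args)
        # Normalize single-tensor outputs so they can feed the next block.
        args = out if isinstance(out, tuple) else (out,)
    return args


# Usage: callers opt in to non-reentrant checkpointing explicitly.
blocks = [torch.nn.Linear(8, 8) for _ in range(3)]
x = torch.randn(2, 8, requires_grad=True)
(out,) = checkpoint_blocks(blocks, (x,), use_reentrant=False)
out.sum().backward()
```

With this shape, existing callers keep the old default, and anyone using torch.compile passes use_reentrant=False at the call site.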

an1lam avatar Mar 31 '25 18:03 an1lam

@jnwei apologies for the ping out of nowhere, but are you the right person to request a review from? I see you're the most recent person to merge a PR.

an1lam avatar Mar 31 '25 18:03 an1lam