TransformerEngine
[PyTorch] Support `torch.amp.autocast` in TE checkpoint
This PR modifies `te.distributed.checkpoint(...)` so that the `torch.amp.autocast(...)` context that was active during the forward pass is restored when activations are recomputed during the backward pass.
Reported in #787.
/te-ci pytorch