[Compile] Understand why FSDP2 saves both SDPA out and wo in for bwd
With FSDP2 and per-transformer-block compile, torch.compile saves both the SDPA output and the contiguous transposed tensor (the input to `wo`) for backward:
https://github.com/pytorch/torchtitan/blob/7e93822e402c3f470bb7ddb925bbc43701bf8573/torchtitan/models/llama/model.py#L210-L213
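For reference, a minimal sketch of that output path, showing the two tensors in question (shapes and variable names here are illustrative, not the exact torchtitan ones):

```python
import torch
import torch.nn.functional as F

bs, n_heads, seq_len, head_dim = 2, 8, 128, 64
dim = n_heads * head_dim

xq = torch.randn(bs, n_heads, seq_len, head_dim)
xk = torch.randn(bs, n_heads, seq_len, head_dim)
xv = torch.randn(bs, n_heads, seq_len, head_dim)
wo = torch.nn.Linear(dim, dim, bias=False)

sdpa_out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)  # (bs, n_heads, seq_len, head_dim)
wo_in = sdpa_out.transpose(1, 2).contiguous().view(bs, seq_len, dim)   # (bs, seq_len, dim)
out = wo(wo_in)
# FSDP2 + per-block compile reportedly saves both `sdpa_out` and `wo_in` for backward;
# simpleFSDP + full-model compile reportedly saves only `sdpa_out`.
```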
However, with simpleFSDP and full-model compile, torch.compile saves only the SDPA output. This means that FSDP2 saves an extra (bs, seq_len, dim) tensor per transformer block.
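As a rough illustration of the cost, assuming Llama3-8B-like sizes (bs=1, seq_len=8192, dim=4096, bf16; these numbers are illustrative assumptions), the extra saved tensor is about 64 MiB per block:

```python
# Size of one extra (bs, seq_len, dim) activation in bf16 (2 bytes/element).
bs, seq_len, dim, bytes_per_elem = 1, 8192, 4096, 2
extra_bytes = bs * seq_len * dim * bytes_per_elem
print(extra_bytes / 2**20, "MiB per block")  # 64.0 MiB per block
```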
Traditionally, the SDPA output is required for the SDPA backward, and the input to wo is required for the wo backward. However, it may be profitable memory-wise to save only one and recompute the other (e.g. recompute the SDPA output by undoing the transpose of the wo input).
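A self-contained sketch of that recompute idea (shapes are illustrative): the wo input is just a reshaped copy of the SDPA output, so one can be reconstructed from the other instead of saving both:

```python
import torch

bs, n_heads, seq_len, head_dim = 2, 8, 128, 64
sdpa_out = torch.randn(bs, n_heads, seq_len, head_dim)
wo_in = sdpa_out.transpose(1, 2).contiguous().view(bs, seq_len, n_heads * head_dim)

# Undo the view + transpose to rebuild the SDPA output from the saved wo input.
recomputed = wo_in.view(bs, seq_len, n_heads, head_dim).transpose(1, 2)
assert torch.equal(recomputed, sdpa_out)
```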
One question is why the activations saved for backward differ between simpleFSDP with full-model compile and FSDP2 with per-transformer-block compile.