Warn that SAC + Compile for MoE models is not yet supported
Stacked PRs:
- ->#2052
Warn that SAC + Compile for MoE models is not yet supported. Behavior should be identical for MoE blocks; dense blocks are no longer compiled.
This also fixes another issue: in SAC, CheckpointWrapper is applied to every submodule, whereas Full AC applies it only at the block level. That has broken the logic of apply_compile ever since https://github.com/pytorch/torchtitan/pull/1895.
What's the issue with compile + SAC + MoE?
SAC wraps each submodule of TransformerBlock separately (in _apply_op_sac_to_transformer_block_with_flex), which turns each submodule of TransformerBlock into an instance of CheckpointWrapper.
This makes the isinstance() check in apply_compile fail, so it falls back to the else branch, causing a compile error.
So #1895 only works with Full AC, not SAC. AC(compile(moe)) works, but SAC(compile(moe)) doesn't work.
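A toy model of the failure may help: once SAC wraps each submodule, an isinstance()-based dispatch no longer recognizes the MoE submodule. This is a pure-Python sketch; Module, Attention, MoE, CheckpointWrapper and compile_path are hypothetical stand-ins, not torchtitan's or PyTorch's actual classes, and the real else branch is where compilation errors out rather than merely returning a label.

```python
# Toy model (pure Python, no torch). Every class and function here is a
# hypothetical stand-in, not torchtitan's or PyTorch's actual code.


class Module:
    """Minimal stand-in for nn.Module."""


class Attention(Module):
    pass


class MoE(Module):
    pass


class CheckpointWrapper(Module):
    """Stand-in for a wrapper that adds activation checkpointing to one module."""

    def __init__(self, inner: Module) -> None:
        self.inner = inner


def compile_path(submod: Module) -> str:
    # Hypothetical isinstance()-based dispatch in the spirit of apply_compile:
    # MoE submodules take a special path, everything else the generic one.
    if isinstance(submod, MoE):
        return "moe path"
    return "else branch"


# Full AC: one wrapper around the whole block, so the submodules stay bare.
full_ac_submods = [Attention(), MoE()]
# SAC: each submodule of the block is wrapped individually.
sac_submods = [CheckpointWrapper(Attention()), CheckpointWrapper(MoE())]

print([compile_path(m) for m in full_ac_submods])  # ['else branch', 'moe path']
print([compile_path(m) for m in sac_submods])      # ['else branch', 'else branch']
```

Under Full AC the bare MoE still matches isinstance(submod, MoE); under SAC the wrapped MoE silently takes the else branch.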
@wwwjn According to @ezyang, https://github.com/pytorch/pytorch/pull/167844 fixes SAC around torch.compile regions.
So everything should be fixed now; we just need to remove the hack in _apply_op_sac_to_transformer_block_with_flex and test.
> fixes SAC around torch.compile region
So there are two cases here, depending on whether you care that compiling makes your graph opaque. The fix in that PR primarily addresses one of them. If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph. But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.
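To make the granularity point concrete, here is a toy pure-Python model, not PyTorch's SAC implementation. Ops are strings, a tuple of ops stands for one opaque compiled region, and the policy can only see that region as a single op (the name "compiled_region" and the save_matmuls policy are hypothetical). Per-op choices work in eager mode, but inside the compiled region every op inherits one decision.

```python
# Toy model of checkpointing granularity (pure Python, not PyTorch's SAC).
# Ops are strings; a tuple of ops stands for one opaque compiled region.
from typing import Callable, List, Tuple, Union

SAVE, RECOMPUTE = "save", "recompute"
Node = Union[str, Tuple[str, ...]]


def save_matmuls(op: str) -> str:
    # Hypothetical per-op policy: keep matmul outputs, recompute everything else.
    return SAVE if op == "mm" else RECOMPUTE


def checkpoint_plan(graph: List[Node], policy: Callable[[str], str]) -> List[Tuple[str, str]]:
    result = []
    for node in graph:
        if isinstance(node, tuple):
            # The policy sees the compiled region as a single op, so every op
            # inside it inherits one decision.
            decision = policy("compiled_region")
            result.extend((op, decision) for op in node)
        else:
            result.append((node, policy(node)))
    return result


eager = ["mm", "relu", "mm", "softmax"]
compiled = [("mm", "relu", "mm", "softmax")]  # same ops, one compiled region

print(checkpoint_plan(eager, save_matmuls))
# [('mm', 'save'), ('relu', 'recompute'), ('mm', 'save'), ('softmax', 'recompute')]
print(checkpoint_plan(compiled, save_matmuls))
# [('mm', 'recompute'), ('relu', 'recompute'), ('mm', 'recompute'), ('softmax', 'recompute')]
```

With the whole region compiled, the only options are "save everything" or "recompute everything", which is the Full AC behavior.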
To check my understanding:
> If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph.
So if only FlexAttn is compiled (not whole transformer layers or their submodules), SAC works.
> But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.
Say we compile each transformer layer: do you mean we can only save or recompute all the ops within that layer, and cannot specify which ops to save in the SAC region?
@soulitzer
> But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.
Is this the same as the Full AC behavior, or do you mean something else? I think I was already aware of this behavior.
@wwwjn @tianyu-l yeah, I think your understanding is correct: either save all activations needed for backward that are computed within the compiled region, or recompute all ops, just like Full AC.
> So if only FlexAttn is compiled (not each transformer layers / or submodule of transformer layers), SAC works.
Yes, but the existing policy needs to be updated to handle the Inductor HOP (higher-order op).
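The thread does not say how the policy should be updated, so the following is only a toy pure-Python sketch of the shape of the problem. A policy keyed on a list of op names has no entry for the HOP that stands in for a compiled region, so it silently falls through to its default. SAVE_LIST, COMPILED_REGION_HOP, existing_policy and make_hop_aware are hypothetical placeholders, not real PyTorch or torchtitan identifiers.

```python
# Toy sketch (pure Python): extending a name-based SAC policy so it handles a
# HOP standing in for a whole compiled region. The op names and the HOP name
# are hypothetical placeholders, not real PyTorch operator identifiers.
from typing import Callable

SAVE, RECOMPUTE = "save", "recompute"
SAVE_LIST = {"mm", "flex_attention"}
COMPILED_REGION_HOP = "compiled_region_hop"  # hypothetical name


def existing_policy(op: str) -> str:
    # Ops outside the save list are recomputed, including any unknown HOP.
    return SAVE if op in SAVE_LIST else RECOMPUTE


def make_hop_aware(base: Callable[[str], str], region_decision: str) -> Callable[[str], str]:
    def policy(op: str) -> str:
        if op == COMPILED_REGION_HOP:
            # The whole compiled region gets one explicit decision.
            return region_decision
        return base(op)
    return policy


updated_policy = make_hop_aware(existing_policy, SAVE)

print(existing_policy(COMPILED_REGION_HOP))  # recompute (falls through silently)
print(updated_policy(COMPILED_REGION_HOP))   # save
print(updated_policy("relu"))                # recompute
```

Making region_decision an explicit argument keeps the all-or-nothing choice for the compiled region visible instead of leaving it to the policy's default.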