
Warn that SAC + Compile for MoE models is not yet supported

Open xmfan opened this issue 1 month ago • 6 comments

Stacked PRs:

  • ->#2052

Warn that SAC + Compile for MoE models is not yet supported. Behavior should be identical for MoE blocks; dense blocks are no longer compiled.

[image]

This also fixes another issue: CheckpointWrapper is applied to every submodule under SAC, but only at the block level under Full AC. That has broken the logic of apply_compile ever since https://github.com/pytorch/torchtitan/pull/1895.

xmfan · Nov 18 '25 01:11

What's the issue with compile + SAC + MoE?

SAC wraps each submodule of TransformerBlock separately (in _apply_op_sac_to_transformer_block_with_flex), which makes each submodule of TransformerBlock an instance of CheckpointWrapper.

This makes the isinstance() check fail and fall through to the else branch, causing a compile error.

So #1895 only works with Full AC, not SAC: AC(compile(moe)) works, but SAC(compile(moe)) doesn't.

wwwjn · Nov 18 '25 05:11
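For readers following along, here is a minimal sketch of the dispatch pattern being described. The module and function names are hypothetical (this is not torchtitan's actual apply_compile): under Full AC the CheckpointWrapper sits at the block level, so the children are still MoE / FeedForward instances, whereas under SAC every child is itself a CheckpointWrapper, so neither isinstance() branch matches.

```python
# Hypothetical sketch, not torchtitan's apply_compile: shows why per-child
# isinstance() dispatch breaks once SAC wraps every submodule.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)


class MoE(nn.Module):          # stand-in for the MoE submodule
    pass


class FeedForward(nn.Module):  # stand-in for the dense FFN submodule
    pass


def apply_compile_sketch(block: nn.Module) -> None:
    for name, child in block.named_children():
        if isinstance(child, MoE):
            compiled = torch.compile(child, fullgraph=True)   # MoE path
        elif isinstance(child, FeedForward):
            compiled = torch.compile(child, fullgraph=True)   # dense path
        else:
            # Under SAC each child is a CheckpointWrapper around MoE /
            # FeedForward, so neither branch above matches and we land here.
            if isinstance(child, CheckpointWrapper):
                raise RuntimeError(f"SAC wrapped {name}; cannot dispatch compile")
            continue
        setattr(block, name, compiled)
```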

@wwwjn According to @ezyang

https://github.com/pytorch/pytorch/pull/167844 fixes SAC around torch.compile region

So everything should be fixed now; we just need to remove the hack in _apply_op_sac_to_transformer_block_with_flex and test.

tianyu-l · Nov 18 '25 06:11

fixes SAC around torch.compile region

So there are two cases here, depending on whether you care that compiling makes your graph opaque. The fix there primarily addresses one of the cases. If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph. But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

soulitzer · Nov 18 '25 16:11
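For concreteness, here is a hedged sketch of the SAC(compile(fn)) pattern using PyTorch's public SAC API; the policy and the ops it matches are illustrative only, and whether this runs cleanly depends on a PyTorch build that includes the fix from pytorch/pytorch#167844. The point above is that when fn is a compiled region, the policy cannot address the individual ops inside it, so saving/recomputing effectively happens for the whole graph at once.

```python
# Hedged sketch of SAC(compile(fn)); policy and ops below are illustrative.
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    # Illustrative policy: save matmul outputs, recompute everything else.
    # Ops inside a compiled region are not visible here one by one.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


@torch.compile  # the whole fn becomes one opaque region under SAC
def fn(x, w):
    return torch.mm(torch.relu(x), w)


x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

out = checkpoint(
    fn,
    x,
    w,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```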

To check my understanding:

If you're only compiling a single op like FlexAttention, it is fine to not be able to see into the graph.

So if only FlexAttn is compiled (not each transformer layer or submodule of transformer layers), SAC works.

But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

Say we compile each transformer layer: do you mean we can only save/recompute all the ops within that layer, and cannot specify which ops to save in the SAC region?

wwwjn · Nov 18 '25 18:11

@soulitzer

But for larger graphs, SAC(compile(fn)) will work, but it might not do exactly what you want. You'll only be able to save/recompute at the granularity of that whole graph.

Is this full AC behavior, or do you mean something else? Seems I was already aware of this behavior.

tianyu-l · Nov 18 '25 18:11

@wwwjn @tianyu-l yeah I think your understanding is correct - either save all activations needed for backward that are computed within the compiled region, or recompute all ops, just like full AC.

So if only FlexAttn is compiled (not each transformer layer or submodule of transformer layers), SAC works.

Yes, but the existing policy needs to be updated to handle the inductor HOP.

soulitzer · Nov 18 '25 20:11
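As a rough illustration of what "update the existing policy" could look like, here is a hedged sketch of an op-level SAC policy built around a save list. The ops shown are examples rather than torchtitan's actual save list, and the identifier of the inductor HOP that a compiled FlexAttention region dispatches as is not given in this thread, so it is left as a TODO rather than guessed.

```python
# Hedged sketch of an op-level SAC policy; op names below are illustrative,
# not torchtitan's actual save list.
import torch
from torch.utils.checkpoint import CheckpointPolicy

_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.higher_order.flex_attention,  # FlexAttention when not compiled
}


def op_sac_policy(ctx, op, *args, **kwargs):
    # TODO (per this thread): also match the inductor HOP that a compiled
    # FlexAttention region dispatches as, so its output is saved rather than
    # recomputed. Its identifier is not specified here, so it is omitted.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
```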