torchtitan
torch.compile each TransformerBlock instead of the whole model
This way we can temporarily enable 2-D parallel compile, and it might make sense to keep compiling per TransformerBlock in the future with PP (we'll see).
We still need to figure out the dynamic shape issue, though.
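For reference, a minimal sketch of what per-TransformerBlock compile can look like (assuming model.layers is an nn.ModuleList of TransformerBlocks, as in the Llama model here; not necessarily the exact diff in this PR):

# replace each TransformerBlock with its compiled version,
# instead of wrapping the whole model in torch.compile
for layer_id, transformer_block in model.layers.named_children():
    model.layers.register_module(layer_id, torch.compile(transformer_block))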
Should we be able to close https://github.com/pytorch/torchtitan/issues/61 after this PR?
Also, do we need to run end-to-end numerics testing?
@awgu yeah, I think it should resolve that. I'll do some e2e benchmarking before landing, so this will likely take a while.
I'll break up some other changes to land them first
when PP is present, we may torch.compile the whole stage module, which is bigger than a transformer block, i.e.
# split the model into pipeline stages
pipe = pipeline(model, ...)
# the stage module may contain multiple TransformerBlocks
stage_mod = pipe.get_stage_module(stage_idx)
# compile the whole stage module rather than each block
stage_mod = torch.compile(stage_mod)
stage = PipelineStage(stage_mod, ...)
It would also allow the code to be more model-agnostic -- there is no transformer_block, layer_id, or model.layers here.
In my case, enabling compilation in this way (per-layer) causes a memory leak
🤔 interesting, how did you observe that?
fwiw this doesn't work out of the box, as it triggers some non-trivial numeric issues. I'm going to leave this PR here until I resolve them. Opening a new PR to turn dynamic shape off so that it works for both 1D and 2D compile.
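(For context, turning dynamic shapes off for the per-block compile can be done with the existing dynamic argument of torch.compile; where exactly it goes in the trainer is up to that follow-up PR:)

# trace each TransformerBlock with static shapes to sidestep the dynamic-shape issue
transformer_block = torch.compile(transformer_block, dynamic=False)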
With each iteration, the memory usage increases and eventually results in OOM. However, just to be clear, I haven't tested this on your entire code, only on a part of it. Adding per-layer compilation causes a memory leak with each iteration. I know that the memory leak might be related to my implementation, so I just wanted to bring this issue to your attention. If you don't observe this in your code, then it's likely an issue on my end. Meanwhile, I didn't observe the numerical issues you mentioned, unless they are a direct consequence of the memory leak. The loss function appears practically the same with layer compilation both enabled and disabled.
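(A minimal, generic way to observe this kind of per-iteration growth, assuming a standard training loop with a step counter; not the reporter's actual setup:)

import torch
# log allocator stats after each optimizer step; steadily increasing
# "allocated" across iterations suggests a leak rather than normal steady-state usage
allocated_gib = torch.cuda.memory_allocated() / 2**30
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"step {step}: allocated {allocated_gib:.2f} GiB, peak {peak_gib:.2f} GiB")
torch.cuda.reset_peak_memory_stats()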
going to merge this given that:
- 2D compile is currently broken; this PR works around it and makes TP compilable again (we should separately figure out the full-model compile issue) cc @bdhirsh
- per-TransformerBlock compile gives us potential later: once cache reuse in torch.compile is enabled, it would drastically improve compile time (cold start and warm start). cc @anijain2305