
torch.compile each TransformerBlock instead of the whole model

Open · wanchaol opened this issue Apr 25 '24 · 7 comments

This way we could temporarily enable 2-D parallel compile, and it might make sense to keep doing per-TransformerBlock compile in the future with PP (we'll see).

We should figure out the dynamic shape issue, though.
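For reference, here is a minimal sketch (not this PR's actual implementation) of what per-TransformerBlock compilation could look like, assuming a Llama-style model whose model.layers container holds the TransformerBlocks; the compile_per_block helper is hypothetical:

import torch

def compile_per_block(model: torch.nn.Module) -> torch.nn.Module:
    # Compile each TransformerBlock individually instead of the whole model.
    # dynamic=False sidesteps the dynamic-shape issue mentioned above.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, dynamic=False))
    return model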

wanchaol · Apr 25 '24

Should we be able to close https://github.com/pytorch/torchtitan/issues/61 after this PR?

Also, do we need to run end-to-end numerics testing?

awgu · Apr 26 '24

> Should we be able to close #61 after this PR?
>
> Also, do we need to run end-to-end numerics testing?

@awgu yeah, I think it should resolve that. I'll do some e2e benchmarking before landing, so this will likely take a while.

wanchaol · Apr 26 '24

I'll break up some other changes and land them first.

wanchaol · Apr 26 '24

When PP is present, we may torch.compile the whole stage module, which is bigger than a transformer block, e.g.:

pipe = pipeline(model, ...)                    # split the model into pipeline stages
stage_mod = pipe.get_stage_module(stage_idx)   # the submodule owned by this stage
stage_mod = torch.compile(stage_mod)           # compile the whole stage as one unit
stage = PipelineStage(stage_mod, ...)

It would also allow the code to be more model-agnostic -- there is no transformer_block, layer_id or model.layers here.

kwen2501 · May 01 '24

In my case, enabling compilation in this way (per-layer) causes a memory leak

chrisociepa · May 02 '24

> In my case, enabling compilation in this way (per-layer) causes a memory leak

🤔 interesting, how did you observe that?

fwiw this doesn't work out of the box, as it triggers some non-trivial numeric issues; I'm going to leave this PR here until I resolve them. Opening a new PR to turn dynamic shapes off so that it works for both 1D and 2D compile.

wanchaol · May 02 '24

With each iteration, memory usage increases and eventually results in an OOM. To be clear, I haven't tested this on your entire code, only on a part of it: adding per-layer compilation causes a memory leak with each iteration. The leak might be related to my implementation, so I just wanted to bring it to your attention; if you don't observe this in your code, then it's likely an issue on my end. Meanwhile, I haven't observed the numerical issues you mentioned, unless they are a direct consequence of the memory leak. The loss appears practically the same with per-layer compilation enabled and disabled.
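As a generic sanity check (not something from this thread), one way to confirm that kind of per-iteration growth is to log the CUDA allocator stats every step and watch whether the allocated number keeps climbing; the log_cuda_memory helper below is hypothetical:

import torch

def log_cuda_memory(step: int) -> None:
    # allocated = memory held by live tensors; reserved = memory cached by the allocator
    allocated_mib = torch.cuda.memory_allocated() / 2**20
    reserved_mib = torch.cuda.memory_reserved() / 2**20
    print(f"step {step}: allocated={allocated_mib:.1f} MiB, reserved={reserved_mib:.1f} MiB")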

chrisociepa · May 03 '24

Going to merge this given that:

  1. 2D compile is currently broken, and this PR works around it so that TP can be compiled again (we should separately figure out the full-model compile issue). cc @bdhirsh
  2. per-TransformerBlock compile gives us potential wins later: once cache reuse in torch.compile is enabled, it would drastically improve compile time (cold start and warm start). cc @anijain2305

wanchaol · May 22 '24