
issues on llama3 compile + (async) TP + AC

Open tianyu-l opened this issue 7 months ago • 20 comments

Bug description

On 8 GPUs, DP2 TP4:

  1. compile + selective op AC + TP: got failure

     File "/data/users/lty/pytorch/torch/_ops.py", line 1317, in __getattr__
       raise AttributeError(
     AttributeError: '_OpNamespace' 'symm_mem' object has no attribute 'fused_all_gather_matmul'

     Note: DP4 TP2 works.

  2. compile + selective op AC + async TP: got very low throughput compared with "full AC + TP"

  3. compile + full AC + async TP: got the following warning and very low throughput (compared with "full AC + TP")

     torch/_inductor/fx_passes/micro_pipeline_tp.py:894] [0/1] no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion

  4. compile + async TP (+ selective 2 AC) is still failing on CI machines: https://github.com/pytorch/torchtitan/actions/runs/14992456541/job/42118899398?pr=1186

Versions

Latest PyTorch built from source, on A100 GPUs, with the debug model.

tianyu-l, May 13 '25 08:05

cc @kwen2501 @danielvegamyhre @bdhirsh

tianyu-l, May 13 '25 08:05

> AttributeError: '_OpNamespace' 'symm_mem' object has no attribute 'fused_all_gather_matmul'

Hmm, I've never seen this before; it seems like the custom op isn't registered somehow?

> torch/_inductor/fx_passes/micro_pipeline_tp.py:894] [0/1] no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion

This warning is expected; many valid graphs will not have candidates for matmul reduce-scatter fusion. If I remember correctly, the backward pass is often like this because we add gradients and then do the reduce-scatter, which doesn't match the matmul reduce-scatter pattern.
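To illustrate why the backward graph can fail to match, here is a minimal toy sketch of a producer-based pattern match. It is not inductor's micro_pipeline_tp implementation: `Node` and `find_matmul_reduce_scatter` are hypothetical stand-ins, and the "forward-like" and "backward-like" graphs are simplified assumptions. The point is only that a reduce-scatter whose input comes from an add (summed gradients) rather than directly from a matmul offers no fusion candidate.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    """Toy stand-in for an FX graph node (hypothetical, not torch.fx.Node)."""
    name: str
    op: str  # e.g. "mm", "add", "reduce_scatter", "placeholder"
    inputs: List["Node"] = field(default_factory=list)


def find_matmul_reduce_scatter(nodes: List[Node]) -> List[Tuple[Node, Node]]:
    """Pair each reduce_scatter with its producer, but only if that producer is a matmul."""
    matches = []
    for node in nodes:
        if node.op != "reduce_scatter":
            continue
        producer = node.inputs[0]
        if producer.op == "mm":
            matches.append((producer, node))
        else:
            # Mirrors the spirit of the inductor log line: not an error, just no candidate.
            print(f"no producer matmul found for {node.name}, skipping fusion")
    return matches


x, w = Node("x", "placeholder"), Node("w", "placeholder")

# Forward-like pattern: mm feeds reduce_scatter directly, so it is a fusion candidate.
mm = Node("mm", "mm", [x, w])
rs_fwd = Node("rs_fwd", "reduce_scatter", [mm])

# Backward-like pattern: gradient contributions are summed before the reduce_scatter,
# so its producer is an add and the pattern does not match.
g1, g2 = Node("g1", "mm", [x, w]), Node("g2", "mm", [x, w])
acc = Node("acc", "add", [g1, g2])
rs_bwd = Node("rs_bwd", "reduce_scatter", [acc])

print(len(find_matmul_reduce_scatter([x, w, mm, rs_fwd])))
print(len(find_matmul_reduce_scatter([g1, g2, acc, rs_bwd])))
```

The first call finds one candidate; the second prints the skip message and finds none, which is the benign situation the warning reports.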

> compile + async TP (+ selective 2 ac) is still failing on CI machines https://github.com/pytorch/torchtitan/actions/runs/14992456541/job/42118899398?pr=1186

This looks like a CUDA driver error on the host machine.

    File "/opt/conda/envs/py_3.11/lib/python3.11/site-packages/torch/distributed/_symmetric_memory/__init__.py", line 120, in get_symm_mem_workspace
      return _SymmetricMemory.rendezvous(tensor)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: CUDA driver error: invalid device ordinal

danielvegamyhre, May 13 '25 15:05

> on 8 GPUs, DP2 TP4 compile + selective op AC + TP: got failure

@tianyu-l What's your command? I could not reproduce the same issue on either A100 or H100. This is my command:

CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 4 --parallelism.enable_async_tensor_parallel

As for the performance issue, as @danielvegamyhre mentioned, this is expected due to the "noise" nodes in the graph. Luca has mentioned some ideas to improve async TP, which we will explore next.

The CI error seems to be a transient infra/machine error: the job is now green, and the original error seems to be related to the infra/machine setup.

fegin, May 14 '25 06:05

@fegin

> This is my command:

You also need selective_ac_option = 'op' instead of the default 2

> The CI error seems to be a transient infra/machine error as it is green and the original error seems to be related to the infra/machine setup error.

Async TP CI has been commented out for months. Even if it's related to the CI machine, we should still fix it.

tianyu-l, May 14 '25 06:05

@tianyu-l

> You also need selective_ac_option = 'op' instead of the default 2

That's the default value. 2 is not default. And I verified that even if I specify --activation_checkpoint.selective_ac_option="op", it still works.

fegin, May 14 '25 16:05

@fegin After I switched to another H100 machine, it works for me.

> That's the default value. 2 is not default.

Oh, I ran with the debugmodel, where the default is 2: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/train_configs/debug_model.toml#L68
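The two values being discussed select different selective-AC modes. The sketch below encodes my understanding of torchtitan's semantics, which is an assumption rather than its actual code: "op" applies a per-op save/recompute policy inside every transformer block, while an integer string "N" wraps every N-th block (counting from 1) in activation checkpointing. `selective_ac_plan` is a hypothetical helper written only to make the difference concrete.

```python
from typing import List, Tuple


def selective_ac_plan(option: str, num_layers: int) -> Tuple[str, List[int]]:
    """Hypothetical helper: interpret a selective_ac_option value.

    "op" -> per-op selective AC (an op-level save/recompute policy) in every block.
    "N"  -> wrap every N-th transformer block (counting from 1) in AC.
    """
    if option == "op":
        return "per-op", list(range(num_layers))
    freq = int(option)  # raises ValueError for anything that is not "op" or an int
    if freq <= 0:
        raise ValueError("selective_ac_option must be 'op' or a positive integer")
    # 0-based block indices whose 1-based position is a multiple of freq.
    return "per-layer", [i for i in range(num_layers) if (i + 1) % freq == 0]


print(selective_ac_plan("op", 4))  # ('per-op', [0, 1, 2, 3])
print(selective_ac_plan("2", 6))   # ('per-layer', [1, 3, 5])
```

Under this reading, "2" checkpoints every other block, which exercises a different compiled graph than per-op SAC; that may help explain why results differed between the debugmodel config and llama3_8b.toml.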

tianyu-l, May 15 '25 02:05

@tianyu-l I couldn't repro this error using either per-op SAC or per-layer SAC. I suspect it is a PyTorch build issue, since that custom op definitely exists. I would try a fresh reinstall of 2.6.0+ (I'm not sure which is the oldest torch version that includes these ops) or a fresh build from source.

> File "/data/users/lty/pytorch/torch/_ops.py", line 1317, in __getattr__
>       raise AttributeError(
>   AttributeError: '_OpNamespace' 'symm_mem' object has no attribute 'fused_all_gather_matmul'

danielvegamyhre, May 15 '25 15:05

Update: @tianyu-l I actually accidentally reproduced the no attribute 'fused_all_gather_matmul' error, strangely not with async TP but with vanilla TP. To me this indicates an inductor caching issue. The original Dynamo-traced graph (before the micro_pipeline_tp post-grad pass) doesn't change between async TP and vanilla TP, so the vanilla TP run finds a cached inductor graph/codegen that matches the trace and tries to run it. However, that cached codegen is the product of a post-grad graph rewrite that used symmetric memory, and with vanilla TP the symmetric memory group is not initialized, so the custom op is not registered.
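The failure mode described above can be reduced to a toy model. This is a minimal sketch, not inductor's caching code: `compile_graph` and `_cache` are hypothetical, the string "graph" is a placeholder for a traced FX graph, and the string replace stands in for the micro_pipeline_tp rewrite. It shows how a cache keyed only on the traced graph hands the async TP artifact to a later vanilla TP run.

```python
import hashlib

# Hypothetical toy compile cache whose key covers only the traced graph,
# not the setting that enables the async TP post-grad rewrite.
_cache = {}


def compile_graph(graph_src: str, enable_async_tp: bool) -> str:
    key = hashlib.sha256(graph_src.encode()).hexdigest()
    if key in _cache:
        return _cache[key]  # cache hit: enable_async_tp is never consulted
    if enable_async_tp:
        # Stand-in for the micro_pipeline_tp pass fusing all-gather + matmul.
        code = graph_src.replace("all_gather -> mm", "symm_mem.fused_all_gather_matmul")
    else:
        code = graph_src
    _cache[key] = code
    return code


graph = "all_gather -> mm"
first = compile_graph(graph, enable_async_tp=True)    # async TP run fills the cache
second = compile_graph(graph, enable_async_tp=False)  # vanilla TP run gets a stale hit
print(second)  # symm_mem.fused_all_gather_matmul
```

In the real system the stale artifact then calls a symm_mem op in a process where symmetric memory was never set up, which matches the AttributeError in the original report.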

Clearing the inductor cache and rerunning resolved the issue.

rm -rf /tmp/torchinductor_${USER}; CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 4

I can see this confusing users, though. Maybe we should think about how to ensure the inductor cache produced by async TP isn't reused for vanilla TP? Maybe @bdhirsh has thoughts on how to do this?
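Until the cache key accounts for this, one possible workaround, instead of deleting the cache before every run, is to give each launch configuration its own cache directory via the `TORCHINDUCTOR_CACHE_DIR` environment variable, which overrides the default `/tmp/torchinductor_${USER}` location. The helpers below (`cache_dir_for`, `launch_env`) are hypothetical and only sketch the idea; the trade-off is that configurations no longer share any cached artifacts.

```python
import hashlib
import os
import tempfile
from typing import Dict, List, Optional


def cache_dir_for(flags: List[str], root: Optional[str] = None) -> str:
    """Hypothetical helper: one inductor cache directory per set of launch flags."""
    root = root or tempfile.gettempdir()
    # Hash the flags in order; "\0" cannot appear inside a CLI argument, so the join is unambiguous.
    digest = hashlib.sha256("\0".join(flags).encode()).hexdigest()[:12]
    return os.path.join(root, f"torchinductor_{digest}")


def launch_env(flags: List[str]) -> Dict[str, str]:
    env = dict(os.environ)
    # Point inductor at a per-configuration cache instead of the shared default.
    env["TORCHINDUCTOR_CACHE_DIR"] = cache_dir_for(flags)
    return env


vanilla = ["--training.compile", "--parallelism.tensor_parallel_degree", "2"]
async_tp = vanilla + ["--parallelism.enable_async_tensor_parallel"]
print(launch_env(vanilla)["TORCHINDUCTOR_CACHE_DIR"] != launch_env(async_tp)["TORCHINDUCTOR_CACHE_DIR"])  # True
# e.g. subprocess.run(["./run_train.sh", *async_tp], env=launch_env(async_tp), check=True)
```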

danielvegamyhre, May 15 '25 16:05

@danielvegamyhre Very interesting finding. I thought inductor would prefix a timestamp to the cache folder, but it looks like I was wrong.

fegin, May 15 '25 17:05

Nice find. @danielvegamyhre are there a set of commands that I can try running to repro? How did you e.g. get asyncTP to not run the first time but run the second time?

bdhirsh, May 15 '25 20:05

> Nice find. @danielvegamyhre are there a set of commands that I can try running to repro? How did you e.g. get asyncTP to not run the first time but run the second time?

Sure, here is a repro:

  1. Run an async TP job (I cleared the cache first to start the repro from a clean slate): rm -rf /tmp/torchinductor_${USER}; NGPU=4 CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 2 --parallelism.enable_async_tensor_parallel

  2. Run a vanilla TP job by running the command WITHOUT clearing the cache and WITHOUT the async TP flag: NGPU=4 CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --parallelism.tensor_parallel_degree 2

danielvegamyhre, May 15 '25 21:05

It looks like we aren't including private inductor configs in the cache key. I still need to make changes to the PR, but the fix is here: https://github.com/pytorch/pytorch/pull/153672/
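The idea behind the fix can be sketched in a few lines: if every config entry, private ones included, is hashed together with the graph, then async TP and vanilla TP compiles of the same trace get different keys. This is a toy, not the code in the PR. It assumes, as I understand torchtitan's async TP setup, that the switch is the private flag `torch._inductor.config._micro_pipeline_tp`; `cache_key` is a hypothetical helper.

```python
import hashlib
import json
from typing import Any, Dict


def cache_key(graph_src: str, inductor_config: Dict[str, Any]) -> str:
    """Toy cache key: hash the graph together with ALL config entries,
    including private ("_"-prefixed) ones that change what post-grad passes do."""
    # sort_keys makes the key independent of dict insertion order.
    payload = json.dumps({"graph": graph_src, "config": inductor_config}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


graph = "all_gather -> mm"
async_tp_cfg = {"_micro_pipeline_tp": True}
vanilla_cfg = {"_micro_pipeline_tp": False}
print(cache_key(graph, async_tp_cfg) == cache_key(graph, vanilla_cfg))  # False
```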

bdhirsh, May 16 '25 00:05

> It looks like we aren't including private inductor configs in the cache key. I need to make changes to the PR but the fix is here: pytorch/pytorch#153672

That was fast! Thanks for fixing that.

danielvegamyhre, May 16 '25 00:05

@tianyu-l @fegin of the items listed in the original issue, I believe the current state can now be summarized as:

  • Async TP + per-op selective AC works, but throughput is lower than expected
  • Async TP CI tests are broken

danielvegamyhre, May 19 '25 23:05

@danielvegamyhre thanks for the update! How about AsyncTP + full AC? Is the throughput also lower than expected?

tianyu-l, May 20 '25 01:05

> @danielvegamyhre thanks for the update! How about AsyncTP + full AC? Is the throughput also lower than expected?

When I tested earlier with Llama3 70b, the perf improvements with async TP were solid: around 12-15% higher TPS with bf16 and float8 tensorwise. Note that the data below is from March 25th, after I fixed some issues. We may want to rerun these benchmarks periodically, perhaps in a nightly CI job?

Llama3 70b training runs on 128 H100s with full AC, using FSDP=16, TP=8

  • bf16 (vanilla TP): 598 TPS, peak memory 71.51 GB
  • bf16 (async TP): 673 TPS, peak memory 71.08 GB (+12.54% TPS vs vanilla TP)
  • float8 tensorwise (vanilla TP): 820 TPS, peak memory 55.26 GB
  • float8 tensorwise (async TP): 950 TPS, peak memory 55.91 GB (+15.85% TPS vs vanilla TP)
  • float8 rowwise (vanilla TP): 540 TPS, peak memory 71.46 GB
  • float8 rowwise (async TP): 560 TPS, peak memory 70.65 GB (+3.7% TPS vs vanilla TP, but still unexpectedly lower than bf16)
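The percentages in the list are relative TPS gains of async TP over vanilla TP. A short check that recomputes them from the reported TPS numbers (FSDP=16 x TP=8 = 128 GPUs); `tps_gain_pct` is just an illustrative helper name:

```python
def tps_gain_pct(vanilla_tps: float, async_tps: float) -> float:
    """Relative throughput change of async TP over vanilla TP, in percent."""
    return (async_tps / vanilla_tps - 1.0) * 100.0


# (vanilla TP TPS, async TP TPS) from the Llama3 70b runs above.
runs = {
    "bf16": (598, 673),
    "float8 tensorwise": (820, 950),
    "float8 rowwise": (540, 560),
}
for name, (vanilla, async_tp) in runs.items():
    print(f"{name}: {tps_gain_pct(vanilla, async_tp):+.2f}% TPS")
```

This prints +12.54%, +15.85% and +3.70%, matching the list; it also makes visible that float8 rowwise with async TP (560 TPS) is still below the bf16 vanilla TP baseline (598 TPS).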

Float8 rowwise yields a poor baseline with vanilla TP and a smaller speedup with async TP (due to the issue described here https://github.com/pytorch/pytorch/issues/149990). We should look into why the float8 rowwise vanilla TP baseline is so low, though. I just reran vanilla TP benchmarks locally with Llama3 8b with FSDP=4, TP=2 and found that vanilla TP perf is also bad there.

It can't simply be that the problem size after TP sharding is too small for float8 rowwise to be a net benefit, since we see the same issue with Llama3 70b, where the GEMM sizes are much larger. I created an issue for this: https://github.com/pytorch/torchtitan/issues/1207

danielvegamyhre, May 20 '25 03:05

I did some tests with the latest PyTorch and TorchTitan. The results contradict some of the observations above. For llama3 8b, full AC, TP8, the performance is quite bad with or without async TP. On the other hand, we see some performance improvement from async TP in the other settings (TP2 and TP4). This needs more investigation.

TorchTitan hash: 5de055e30e4b288c646c2dffe3fdd8d18d1cdab9 (5/20)
PyTorch hash: b0e5402377c0296c45ca7ae8a944427616521604 (5/19)

H100, llama3 8b, full AC, FSDP:8, bf16, local batch size: 2, global batch size: 16

  • With torch.compile(): 6,068 TPS, 32.91 GB memory

H100, llama3 8b, full AC, FSDP:4, TP:2, bf16, local batch size: 4, global batch size: 16

  • Vanilla TP (with torch.compile()): 5,681 TPS, 42.50 GB memory
  • Async TP: 5,826 TPS, 42.62 GB memory

H100, llama3 8b, full AC, FSDP:2, TP:4, bf16, local batch size: 8, global batch size: 16

  • Vanilla TP (with torch.compile()): 5,120 TPS, 41.98 GB memory
  • Async TP: 5,532 TPS, 42.06 GB memory

H100, llama3 8b, full AC, TP:8, bf16, local batch size: 16, global batch size: 16

  • Vanilla TP (with torch.compile()): 576 TPS, 49 GB memory
  • Vanilla TP (without torch.compile()): 564 TPS, 56.67 GB memory
  • Async TP: 595 TPS, 50.68 GB memory
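A quick sketch that relates the configurations above: on 8 GPUs, the FSDP degree is 8 / TP, and the global batch size is the local batch size times that data-parallel degree (16 in every run). It then computes the async TP gain that each run implies; the data is copied from the results above and `summarize` is an illustrative helper.

```python
from typing import List, Tuple

NGPU, GLOBAL_BATCH = 8, 16

# (tp_degree, local_batch_size, vanilla_tps, async_tps), copied from the runs above.
RUNS = [(2, 4, 5681, 5826), (4, 8, 5120, 5532), (8, 16, 576, 595)]


def summarize(runs) -> List[Tuple[int, int, float]]:
    rows = []
    for tp, local_bs, vanilla, async_tp in runs:
        dp = NGPU // tp  # remaining GPUs form the FSDP (data-parallel) dimension
        # Each data-parallel rank holds local_bs samples; TP ranks share the same batch.
        assert dp * local_bs == GLOBAL_BATCH
        rows.append((tp, dp, (async_tp / vanilla - 1.0) * 100.0))
    return rows


for tp, dp, gain in summarize(RUNS):
    print(f"TP{tp}, FSDP{dp}: async TP {gain:+.1f}% TPS vs vanilla TP")
```

This gives roughly +2.6% (TP2), +8.0% (TP4) and +3.3% (TP8). The TP8 run has an FSDP degree of 1, i.e. no FSDP, which is the setting where absolute TPS is far below the others.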

fegin, May 20 '25 20:05

@fegin With TP8 I'm assuming you are not using FSDP. By default, no mixed precision training is used in that setup. Did you manually enable bf16 training?

tianyu-l, May 21 '25 00:05

@tianyu-l Good point, I forget this every time :( Yes, that may be the root cause; I'll verify it. Besides that, I think the other numbers are reasonable. I noticed that your original report was on A100, which has a different intra-node interconnect from H100. That may be why you are seeing some performance issues. The reduce_scatter fusion is something we could improve where the pattern allows it, but from my perspective it is more a missing feature than a bug.

The main issue is still CI.

fegin, May 21 '25 00:05

For the CI issue, the error consistently comes from a CUDA driver API call that sets up a virtual address, which makes me think it may be related to the machine settings. As we are getting H100 machines next week, I'll wait for them and retest.

fegin, May 23 '25 05:05

@danielvegamyhre finally managed to merge the cache fix - does that cover this issue? https://github.com/pytorch/pytorch/pull/153672

also cc @xmfan

bdhirsh, Jun 11 '25 18:06

> @danielvegamyhre finally managed to merge the cache fix - does that cover this issue? pytorch/pytorch#153672
>
> also cc @xmfan

Yep, using the latest torch nightly build I'm no longer able to repro this - thanks!

danielvegamyhre, Jun 11 '25 20:06