
[Draft][torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends

Open • gshtras opened this pull request 5 months ago • 1 comment

Essential Elements of an Effective PR Description Checklist

  • [x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • [x] The test plan, such as providing test command.
  • [x] The test results, such as pasting the results comparison before and after, or e2e results
  • [ ] (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

An extension of https://github.com/vllm-project/vllm/pull/16756 to the V1 unified attention backend (and its fallback split-attention path). The fusion only takes effect together with https://github.com/vllm-project/vllm/pull/19158, which enables full CUDA graph capture for this backend.
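Numerically, the fusion works because static per-tensor FP8 quantization is an elementwise epilogue. Dividing the attention output by the static scale and saturating to the FP8 range gives the same result whether a separate quant kernel does it or the attention kernel does it through output_scale. The following is a minimal pure-Python sketch under stated assumptions: quantization is modeled as clamp(x / scale, -240, 240) for float8_e4m3fnuz (largest finite value 240), rounding to the FP8 grid is omitted, and `toy_attention` and `quantize_fp8` are hypothetical stand-ins, not vLLM functions.

```python
import math
from typing import List, Optional

FP8_E4M3FNUZ_MAX = 240.0  # largest finite float8_e4m3fnuz value


def quantize_fp8(x: float, scale: float) -> float:
    """Static per-tensor quantization: divide by the scale, saturate to FP8 range.

    Simplification: rounding to the FP8 value grid is omitted.
    """
    return max(-FP8_E4M3FNUZ_MAX, min(FP8_E4M3FNUZ_MAX, x / scale))


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def toy_attention(
    q: List[float],
    keys: List[List[float]],
    values: List[List[float]],
    output_scale: Optional[float] = None,
) -> List[float]:
    """Single-query, single-head scaled dot-product attention.

    If output_scale is given, the output is quantized in the epilogue
    (the "fused" path) instead of by a separate quant op.
    """
    sm_scale = 1.0 / math.sqrt(len(q))
    scores = [sm_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    probs = softmax(scores)
    out = [sum(p * v[d] for p, v in zip(probs, values)) for d in range(len(values[0]))]
    if output_scale is not None:
        out = [quantize_fp8(o, output_scale) for o in out]
    return out


q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [0.2, -0.3, 1.0], [-1.0, 1.0, 0.0]]
values = [[10.0, -4.0], [3.0, 8.0], [-6.0, 1.0]]
act_scale = 0.02  # hypothetical static activation scale

# Unfused: attention output first, then a separate quantization step.
unfused = [quantize_fp8(o, act_scale) for o in toy_attention(q, keys, values)]
# Fused: attention writes the quantized output directly.
fused = toy_attention(q, keys, values, output_scale=act_scale)
assert fused == unfused
```

The toy uses one precision throughout. In the real graphs, the unfused path stores the attention output as bf16 before quantizing, so the fused and unfused results may differ by that extra rounding step.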

Also fixes the fusion path to support an attention output tensor initialized with torch.zeros. Before https://github.com/vllm-project/vllm/pull/19784, this tensor was initialized with torch.empty.
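In the traced graph, torch.zeros appears as aten.full.default with a fill value of 0.0, as the full_default node in the graphs below shows. A pattern that matches only aten.empty would therefore presumably stop firing once the buffer is zero-initialized. The sketch below illustrates the idea as a predicate that accepts either allocation. `Node` and `is_output_buffer_alloc` are hypothetical simplifications for illustration only, not vLLM's pattern-matching code, and the target strings are shortened forms of the op names in the graph dumps.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class Node:
    """Hypothetical, simplified stand-in for an FX graph node."""
    target: str
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)


def is_output_buffer_alloc(node: Node) -> bool:
    """Accept both ways the attention output buffer can be allocated.

    torch.empty traces to aten.empty.memory_format; torch.zeros traces to
    aten.full.default with a fill value of 0.0.
    """
    if node.target == "aten.empty.memory_format":
        return True
    if node.target == "aten.full.default":
        # args = (shape, fill_value); only a zero fill is an equivalent buffer.
        return len(node.args) >= 2 and node.args[1] == 0.0
    return False


print(is_output_buffer_alloc(Node("aten.full.default", (["s0", 32, 128], 0.0))))
```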

Test Plan

In V1, the feature requires full CUDA graph capture. To enable both, pass: -O '{"pass_config":{"enable_attn_fusion":true,"enable_noop":true},"full_cuda_graph":true}'
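When scripting test runs, you can generate the -O argument from a dict instead of hand-writing the JSON. The snippet below is a small sketch that uses only the standard library. The keys are copied from the test plan above and may differ in other vLLM versions. The result is intended to be appended to a vLLM command line, for example after vllm serve with a model name.

```python
import json
import shlex

# Keys copied from this PR's test plan; other vLLM versions may name them differently.
compilation_config = {
    "pass_config": {
        "enable_attn_fusion": True,  # attention output + FP8 quant fusion
        "enable_noop": True,  # no-op elimination pass, enabled alongside fusion here
    },
    "full_cuda_graph": True,  # full CUDA graph capture, required in V1 per this PR
}

# Compact separators reproduce the exact string used in the test plan.
config_json = json.dumps(compilation_config, separators=(",", ":"))
# shlex.quote wraps the JSON in single quotes so the shell passes it through intact.
cli_arg = f"-O {shlex.quote(config_json)}"
print(cli_arg)
```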

Test Result

Graph before fusion:

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:243 in forward, code: torch.ops.vllm.unified_attention_with_output(
    auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = full_default, layer_name = 'model.layers.0.self_attn.attn', output_scale = None);  cat = cat_1 = view_9 = full_default = None
    getitem_12: "bf16[s0, 32, 128]" = auto_functionalized_1[1];  auto_functionalized_1 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1280 in scaled_fp8_quant, code: torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    view_16: "bf16[s0, 4096]" = torch.ops.aten.reshape.default(getitem_12, [-1, 4096]);  getitem_12 = None
    auto_functionalized_2 = torch.ops.higher_order.auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result = empty_1, input = view_16, scale = arg9_1);  empty_1 = view_16 = None
    getitem_14: "f8e4m3fnuz[s0, 4096]" = auto_functionalized_2[1];  auto_functionalized_2 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(getitem_14, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  getitem_14 = arg10_1 = arg9_1 = arg11_1 = None

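Graph after fusion. The static_scaled_fp8_quant node is gone: unified_attention_with_output now writes directly into the FP8 buffer empty_1 (viewed as [s0, 32, 128]) with output_scale = arg9_1, and the zero-initialized bf16 buffer full_default is no longer used: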

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False);  full_default = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
    # No stacktrace found for following nodes
    reshape_default_62: "f8e4m3fnuz[s0, 32, 128]" = torch.ops.aten.reshape.default(empty_1, [-1, 32, 128]);  empty_1 = None
    auto_functionalized_191 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = reshape_default_62, layer_name = 'model.layers.0.self_attn.attn', output_scale = arg9_1);  cat = cat_1 = view_9 = reshape_default_62 = None
    getitem_639: "f8e4m3fnuz[s0, 32, 128]" = auto_functionalized_191[1];  auto_functionalized_191 = None
    reshape_default_63: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.reshape.default(getitem_639, [-1, 4096]);  getitem_639 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(reshape_default_63, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  reshape_default_63 = arg10_1 = arg9_1 = arg11_1 = None
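The change between the two graphs can be summarized as a toy pattern rewrite. This is a simplified pure-Python sketch, not vLLM's actual pass, which operates on torch.fx graphs. `Op`, `fuse_attn_quant`, and the node kinds are hypothetical. The sketch matches attention -> reshape -> static_scaled_fp8_quant, points the attention output at the FP8 buffer, sets output_scale to the quant scale, and reroutes the quant consumer (here the scaled matmul) to the reshaped attention output.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Op:
    """Hypothetical graph node: a unique name, an op kind, and named inputs.

    String input values refer to other nodes (or graph inputs) by name.
    """
    name: str
    kind: str
    inputs: Dict[str, object] = field(default_factory=dict)


def fuse_attn_quant(graph: List[Op]) -> List[Op]:
    """Toy attention + static FP8 quant fusion (modifies nodes in place).

    Matches attention -> reshape -> static_scaled_fp8_quant and rewrites it
    so attention writes FP8 output directly, then drops the quant node.
    """
    by_name = {op.name: op for op in graph}
    result = list(graph)
    for quant in graph:
        if quant.kind != "static_scaled_fp8_quant":
            continue
        reshape = by_name.get(quant.inputs.get("input"))
        if reshape is None or reshape.kind != "reshape":
            continue
        attn = by_name.get(reshape.inputs.get("x"))
        if attn is None or attn.kind != "attention":
            continue
        if attn.inputs.get("output_scale") is not None:
            continue  # already fused
        # Attention now writes into the FP8 buffer and quantizes with the same scale.
        attn.inputs["output"] = quant.inputs["result"]
        attn.inputs["output_scale"] = quant.inputs["scale"]
        # Consumers of the quant result read the reshaped attention output instead.
        for op in result:
            for key, value in list(op.inputs.items()):
                if value == quant.name:
                    op.inputs[key] = reshape.name
        result = [op for op in result if op is not quant]
    return result


graph = [
    Op("full_default", "zeros"),  # bf16 attention output buffer
    Op("empty_1", "empty_fp8"),  # FP8 buffer allocated for the quant op
    Op("attn", "attention", {"output": "full_default", "output_scale": None}),
    Op("view_16", "reshape", {"x": "attn", "shape": (-1, 4096)}),
    Op("quant", "static_scaled_fp8_quant",
       {"result": "empty_1", "input": "view_16", "scale": "arg9_1"}),
    Op("mm", "scaled_mm", {"a": "quant", "b": "arg10_1", "scale_a": "arg9_1"}),
]
fused = fuse_attn_quant(graph)
print([op.name for op in fused])  # ['full_default', 'empty_1', 'attn', 'view_16', 'mm']
```

As in the real graph after fusion, the zero-initialized buffer (full_default) is left unused rather than removed by this step.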

Performance impact

TBD

gshtras, Jun 17 '25 20:06

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by unblocking them in your fastcheck build on the Buildkite UI (linked in the PR checks section). If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot], Jun 17 '25 20:06

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot], Jul 08 '25 15:07


Looks good apart from the remaining comments. Is it possible to see unit tests running in AMD CI somewhere?

Verified this locally for now, until AMD tests are running again for PRs.

gshtras, Sep 09 '25 14:09