
Qualcomm AI Engine Direct - Add MHA2SHA pass

Open shewu-quic opened this issue 1 month ago • 5 comments

Background

We observed that quantizing and compiling the original SHA model requires a significant amount of time, and switching to the MHA model speeds up this process. Therefore, we investigated whether converting the MHA model to SHA after quantization is feasible. However, we cannot perform this conversion during the to_edge transformation, as splitting the convolution weights for SHA would require modifying the state_dict, which is not permitted at that stage. We therefore decided to apply this pass during qnn_preprocess.
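To make the constraint concrete, here is a minimal sketch of the weight split with made-up shapes (the actual pass rewrites the lowered graph during qnn_preprocess rather than calling these APIs directly):

import torch

# Hypothetical dimensions, for illustration only.
n_heads, head_dim, embed_dim = 8, 64, 512

# An MHA projection expressed as a 1x1 conv: weight [n_heads * head_dim, embed_dim, 1, 1].
mha_weight = torch.randn(n_heads * head_dim, embed_dim, 1, 1)

# SHA splits the output channels into one conv weight per head, which is why
# doing the conversion before qnn_preprocess would have to modify the state_dict.
sha_weights = torch.chunk(mha_weight, n_heads, dim=0)

x = torch.randn(1, embed_dim, 1, 16)
per_head = [torch.nn.functional.conv2d(x, w) for w in sha_weights]

# Concatenating the per-head outputs reproduces the original MHA projection.
assert torch.allclose(
    torch.cat(per_head, dim=1), torch.nn.functional.conv2d(x, mha_weight), atol=1e-5
)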

Summary:

  • Integrated the MHA-to-SHA pass and implemented it in qnn_preprocess
  • Refactored MHA in static llama
    • Included SpinQuant R3 support and masked softmax for the MHA model in static llama
    • Combined the n_heads key-value caches into a single cache per layer to decrease the number of inputs and outputs, which improves performance (see the sketch after this list)
  • Deprecated the ShiftPointer KV updater mode
    • Since each layer now has its own KV cache, the V cache no longer benefits from ShiftPointer, which previously avoided copying the new V cache to the input V cache. To prevent user confusion, ShiftPointer mode has been deprecated
  • Applied the correct input template for SmolLM2 135M
  • Corrected the quantization annotation for reshape
  • Removed outdated code from CanonicalizeConv
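As a back-of-the-envelope sketch of the I/O reduction from the combined cache (layer and head counts below are hypothetical):

# Illustrative I/O counting with made-up model dimensions.
n_layers, n_heads = 16, 8

# Before: one K and one V cache input (and output) per head per layer.
per_head_io = 2 * n_layers * n_heads   # 256 cache inputs and 256 outputs

# After: the n_heads caches of each layer are combined into one tensor of
# shape [b, n_heads, cache_len, head_dim], leaving only per-layer I/O.
per_layer_io = 2 * n_layers            # 32 cache inputs and 32 outputs

print(per_head_io, per_layer_io)  # 256 32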

Results

Following the README settings, we tested on SM8750 with QNN 2.37 and compared the new convert_mha_to_sha pass with the original SHA structure.

[image: benchmark results]

shewu-quic avatar Oct 29 '25 06:10 shewu-quic

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15438

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: You can merge normally! (1 Unrelated Failure)

As of commit d261b4eac78032c1326dd3f0c9d6ccf22f89c3d9 with merge base 04f1e4d22383ffcbc770acf5002348e3f95082a2:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 29 '25 06:10 pytorch-bot[bot]

@pytorchbot label "release notes: qualcomm"

shewu-quic avatar Oct 29 '25 06:10 shewu-quic

Hi @cccclai, this PR migrates the MHA2SHA transformation from the source level to a pass applied in qnn_preprocess. It significantly improves lowering time, including quantization and compilation time. Could you please take a look?

Thanks

shewu-quic avatar Oct 30 '25 00:10 shewu-quic

Hi, since it's a really big change and the MHA2SHA pass seems complicated, can you add a test for the pass here: https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_passes.py? Passes can be fragile, so I'm trying to make sure we have it covered in tests.

cccclai avatar Oct 31 '25 18:10 cccclai

Hi, since it's a really big change and the MHA2SHA pass seems complicated, can you add a test for the pass here: https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_passes.py? Passes can be fragile, so I'm trying to make sure we have it covered in tests.

Thanks for pointing that out. I have added a test case to check the functionality of MHA2SHA.

shewu-quic avatar Nov 03 '25 10:11 shewu-quic

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D87290896.

meta-codesync[bot] avatar Nov 18 '25 01:11 meta-codesync[bot]

Hi @cccclai, I have rebased. Can I get a review on this PR?

shewu-quic avatar Dec 01 '25 05:12 shewu-quic

Hi @cccclai, I have rebased. Can I get a review on this PR?

Yes, sorry. I'll try to do it tomorrow.

cccclai avatar Dec 01 '25 05:12 cccclai

Combined the n_heads key-value caches into a single cache per layer to decrease the number of inputs and outputs, which improves performance.

I feel like I still don't follow this part. Can you explain a bit?

cccclai avatar Dec 02 '25 05:12 cccclai

It seems like the SmartMask and ShiftPointer description is no longer accurate after this PR. Can you update https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama#kv-cache-update-mechanism and explain how it works?

cccclai avatar Dec 02 '25 05:12 cccclai

Combined the n_heads key-value caches into a single cache per layer to decrease the number of inputs and outputs, which improves performance.

I feel like I still don't follow this part. Can you explain a bit?

Sure. Originally, we had n_heads * n_layers KV cache inputs and outputs. With this PR, there are only n_layers KV cache inputs and outputs. [image]

shewu-quic avatar Dec 02 '25 05:12 shewu-quic

It seems like the SmartMask and ShiftPointer description is no longer accurate after this PR. Can you update https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama#kv-cache-update-mechanism and explain how it works?

Yes, we removed shift_pointer in this PR. Since the combined KV cache doesn't benefit from shift_pointer mode, we decided to keep only SmartMask. I have also updated the README.md in this PR.
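For reference, a minimal sketch of the SmartMask idea with made-up dimensions (not the actual runtime code): the cache buffer is preallocated, new entries are written in place, and the attention mask unmasks positions as they become valid.

import torch

b, n_heads, cache_len, head_dim = 1, 8, 128, 64
k_cache = torch.zeros(b, n_heads, cache_len, head_dim)   # preallocated buffer
mask = torch.full((1, 1, 1, cache_len), float("-inf"))   # all positions masked

def update(pos: int, new_k: torch.Tensor) -> None:
    # Write the new entry in place instead of shifting the whole buffer,
    # then unmask that position so attention can attend to it.
    k_cache[:, :, pos, :] = new_k
    mask[..., pos] = 0.0

update(0, torch.randn(b, n_heads, head_dim))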

shewu-quic avatar Dec 02 '25 05:12 shewu-quic

the combined KV cache

Can you share how the combined KV cache works here? Is it the one you mentioned that will help memory usage and improve runtime latency?

cccclai avatar Dec 02 '25 18:12 cccclai

Can you share how the combined KV cache works here? Is it the one you mentioned that will help memory usage and improve runtime latency?

Because we use MHA at the source-model level at AOT, the KV cache consists of n_layers tensors of shape [b, n_heads, cache_len, head_dim]. We then insert slice ops to split the KV cache for SHA. At the runtime stage, we change the number of KV caches from n_layers * n_heads to n_layers and adjust the update iterations for each KV cache. The update logic uses SmartMask as before.

This change actually improves runtime performance because it decreases the number of KV cache I/O tensors. It also reduces the time needed to register memory in QNN and set up I/O on QNN HTP. Memory usage should remain unchanged, as the total I/O size stays the same as before.
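A minimal sketch of the slicing described above, with hypothetical shapes (the actual pass inserts the slice ops into the lowered graph):

import torch

# One combined per-layer cache, registered with QNN as a single input.
b, n_heads, cache_len, head_dim = 1, 8, 128, 64
k_cache = torch.randn(b, n_heads, cache_len, head_dim)

# Slice ops split the combined cache so each single-head attention
# consumes its own [b, 1, cache_len, head_dim] view.
per_head_k = [k_cache[:, h : h + 1, :, :] for h in range(n_heads)]

# At runtime, each step then updates n_layers combined caches instead of
# n_layers * n_heads per-head caches; the update itself still uses SmartMask.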

shewu-quic avatar Dec 03 '25 01:12 shewu-quic

Because we use MHA at the source-model level at AOT, the KV cache consists of n_layers tensors of shape [b, n_heads, cache_len, head_dim]. We then insert slice ops to split the KV cache for SHA. At the runtime stage, we change the number of KV caches from n_layers * n_heads to n_layers and adjust the update iterations for each KV cache. The update logic uses SmartMask as before.

This change actually improves runtime performance because it decreases the number of KV cache I/O tensors. It also reduces the time needed to register memory in QNN and set up I/O on QNN HTP. Memory usage should remain unchanged, as the total I/O size stays the same as before.

I see. It seems like the diagram is a bit different from the description here, because we have separate KV caches in the diagram. Can you help update it too?

cccclai avatar Dec 04 '25 02:12 cccclai

I see. It seems like the diagram is a bit different from the description here, because we have separate KV caches in the diagram. Can you help update it too?

I have updated it in the latest commit.

shewu-quic avatar Dec 04 '25 03:12 shewu-quic

Hey, can you help with rebasing? I'm trying to check the internal signal.

cccclai avatar Dec 09 '25 00:12 cccclai

Hey, can you help with rebasing? I'm trying to check the internal signal.

Done. Thanks :)

shewu-quic avatar Dec 09 '25 00:12 shewu-quic

There are some internal errors. Can you fix them with the patch below?

diff --git a/executorch/backends/qualcomm/tests/TARGETS b/executorch/backends/qualcomm/tests/TARGETS
--- a/executorch/backends/qualcomm/tests/TARGETS
+++ b/executorch/backends/qualcomm/tests/TARGETS
@@ -58,6 +58,10 @@
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/backends/qualcomm/_passes:passes",
+        "//executorch/backends/qualcomm/partition:partition",
+        "//executorch/examples/models/llama:transformer_modules",
+        "//executorch/examples/qualcomm/oss_scripts/llama:masking_utils",
+        "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
         "//executorch/backends/qualcomm/builders:builders",
     ],
 )
diff --git a/executorch/backends/qualcomm/utils/targets.bzl b/executorch/backends/qualcomm/utils/targets.bzl
--- a/executorch/backends/qualcomm/utils/targets.bzl
+++ b/executorch/backends/qualcomm/utils/targets.bzl
@@ -5,7 +5,7 @@
     The directory containing this targets.bzl file should also contain both
     TARGETS and BUCK files that call this function.
     """
-    
+
     runtime.python_library(
         name = "utils",
         srcs = glob([
@@ -17,6 +17,7 @@
         deps = [
             "//executorch/exir/backend:backend_details",
             "//executorch/exir/backend:compile_spec_schema",
+            # "//executorch/backends/qualcomm/partition:partition",
             "//executorch/backends/qualcomm/serialization:serialization",
         ],
     )
diff --git a/executorch/examples/qualcomm/oss_scripts/llama/TARGETS b/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
--- a/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
+++ b/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
@@ -8,10 +8,16 @@
 runtime.python_library(
     name = "static_llama",
     srcs = [
+        "model/__init__.py",
+        "model/apply_rope.py",
+        "model/feed_forward.py",
+        "model/layernorm.py",
         "model/static_llama.py",
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/examples/models/llama:transformer_modules",
+        "fbsource//third-party/pypi/transformers:transformers",
     ],
 )
 
diff --git a/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py b/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
--- a/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
+++ b/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
@@ -10,7 +10,7 @@
 
 
 __all__ = [
-    FeedForward_REGISTRY,
-    ROTARY_EMB_REGISTRY,
-    NORM_REGISTRY,
+    "FeedForward_REGISTRY",
+    "ROTARY_EMB_REGISTRY",
+    "NORM_REGISTRY",
 ]

cccclai avatar Dec 09 '25 22:12 cccclai

There are some internal errors. Can you fix them with the patch below?

Fixed

shewu-quic avatar Dec 10 '25 01:12 shewu-quic