Qualcomm AI Engine Direct - Add MHA2SHA pass
Background
We observed that quantizing and compiling the original SHA (single-head attention) model takes a significant amount of time, while switching to the MHA (multi-head attention) model speeds up this process. We therefore investigated whether converting the MHA model to SHA after quantization is feasible. However, we cannot perform this conversion during the to_edge transformation, as splitting the convolution weights into SHA form would require modifying the state_dict, which is not permitted at that stage. We therefore apply this pass during qnn_preprocess.
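For intuition only, here is a minimal sketch of what the weight splitting means, assuming the projection is modeled as a plain nn.Linear; in the real flow the projections are convolutions and the pass rewrites the quantized graph inside qnn_preprocess, so everything below is illustrative rather than the actual implementation:

```python
import torch
import torch.nn as nn

# Illustrative only: split one MHA projection into per-head (SHA) projections
# by slicing its weight along the output dimension.
def split_mha_projection(proj: nn.Linear, n_heads: int) -> nn.ModuleList:
    head_dim = proj.out_features // n_heads
    heads = nn.ModuleList()
    for h in range(n_heads):
        sha = nn.Linear(proj.in_features, head_dim, bias=proj.bias is not None)
        sha.weight.data = proj.weight.data[h * head_dim : (h + 1) * head_dim].clone()
        if proj.bias is not None:
            sha.bias.data = proj.bias.data[h * head_dim : (h + 1) * head_dim].clone()
        heads.append(sha)
    return heads

# Sanity check: concatenated per-head outputs match the original MHA projection.
x = torch.randn(1, 64)
mha_proj = nn.Linear(64, 4 * 16)  # 4 heads, head_dim = 16
sha_projs = split_mha_projection(mha_proj, n_heads=4)
assert torch.allclose(mha_proj(x), torch.cat([p(x) for p in sha_projs], dim=-1), atol=1e-6)
```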
Summary:
- Integrated the MHA-to-SHA pass and implemented it in qnn_preprocess
- Refactored MHA in static llama
- Included SpinQuant R3 support and masked softmax for the MHA model in static llama
- Combined the n_heads key-value cache into a single cache for each layer to decrease the number of inputs and outputs, which enhances performance.
- Deprecated ShiftPointer kv updater mode
- Since each layer now has its own KV cache, the v cache no longer benefits from ShiftPointer, which previously avoided copying the new v cache into the input v cache. To prevent user confusion, ShiftPointer mode has been deprecated (see the sketch after this list)
- Applied the correct input template for SmolLM2 135M
- Corrected the quantization annotation for reshape
- Removed outdated code from CanonicalizeConv
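For reference, the contrast between the two update styles, loosely sketched below with made-up shapes (a conceptual illustration, not the actual runtime buffer layout): SmartMask writes the new entry in place into a fixed cache buffer, while ShiftPointer avoided that copy for the v cache by letting the model's output land in a region that the next step's input view already covers.

```python
import torch

# Conceptual contrast of the two KV-update styles (hypothetical shapes).
cache_len, head_dim, pos = 128, 16, 7
new_v = torch.randn(head_dim)

# SmartMask: fixed-size buffer; the new v is written at the current position
# and the attention mask is opened up for that position.
smart_v_cache = torch.zeros(cache_len, head_dim)
smart_v_cache[pos] = new_v

# ShiftPointer (deprecated by this PR): the output region sits right after the
# input region in one larger buffer, so the next step's input is a shifted
# view that already contains the freshly written v -- no copy needed.
shift_buffer = torch.zeros(cache_len + 1, head_dim)
shift_buffer[pos + 1] = new_v                      # model writes its output here
next_input_view = shift_buffer[1 : 1 + cache_len]  # pointer shift, not a copy
assert torch.equal(next_input_view[pos], new_v)
```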
Results
Following the README settings, tested on SM8750 with QNN 2.37. Compared the new convert_mha_to_sha pass against the original SHA structure.
@pytorchbot label "release notes: qualcomm"
Hi @cccclai, this PR migrates the MHA2SHA transformation from the source level to a pass applied in qnn_preprocess. It significantly improves lowering time, including quantization and compilation time. Could you please take a look?
Thanks
Hi, since it's a really big change and the MHA2SHA pass seems complicated, can you add a test for the pass in https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_passes.py? Passes can be fragile, so I'm trying to make sure we have it covered in tests.
Thanks for pointing this out. I have added a test case to check the functionality of MHA2SHA.
@cccclai has imported this pull request. If you are a Meta employee, you can view this in D87290896.
Hi @cccclai , I have rebased. Can I get a review on this PR?
Yes, sorry, I'll try to do it tomorrow.
Combined the n_heads key-value cache into a single cache for each layer to decrease the number of inputs and outputs, which enhances performance.
I feel like I still don't follow this part, can you explain a bit?
It seems like the SmartMask and ShiftPointer description is no longer accurate after this PR; can you update https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama#kv-cache-update-mechanism and explain how it works?
Sure, originally, we had n_heads * n_layer KV cache inputs and outputs. With this PR, there will be only n_layer KV cache inputs and outputs.
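For example, with made-up sizes, just to make the counting concrete:

```python
# Hypothetical 16-layer, 32-head model: KV cache I/O tensor counts.
n_layers, n_heads = 16, 32
before = 2 * n_layers * n_heads  # separate k and v tensors per head: 1024
after = 2 * n_layers             # one combined k and one v tensor per layer: 32
print(before, after)             # 1024 32
```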
Yes, we removed shift_pointer after this PR. Since the combined KV cache doesn't benefit from shift_pointer mode, we decided to keep only SmartMask. I have also updated the README.md in this PR.
the combined kv cache
Can you share how the combined kv cache works here? Is it the one you mentioned that will help memory usage and improve runtime latency?
Because we use MHA at the source-model level at AOT, there are n_layers KV caches of shape [b, n_heads, cache_len, head_dim]. We then insert slice ops to split the KV cache per head for SHA. At the runtime stage, the number of KV caches changes from n_layers * n_heads to n_layers, and the update iterations are adjusted for each cache. The update logic uses SmartMask as before.
This change actually improves runtime performance because it decreases the number of KV cache I/O tensors. It also reduces the time needed to register memory in QNN and set up I/O in QNN HTP. Memory usage should remain unchanged, as the total I/O size stays the same as before.
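A conceptual sketch of that layout, assuming one layer's combined k cache of shape [b, n_heads, cache_len, head_dim]; the Python slicing below stands in for the slice ops inserted into the lowered graph, and the indexed write stands in for the SmartMask-style runtime update:

```python
import torch

b, n_heads, cache_len, head_dim = 1, 4, 128, 16

# One combined k cache per layer is a single graph input...
k_cache = torch.zeros(b, n_heads, cache_len, head_dim)

# ...and slice ops inside the graph feed each single-head attention.
k_per_head = [k_cache[:, h : h + 1] for h in range(n_heads)]  # n_heads x [b, 1, cache_len, head_dim]

# At runtime, the update loop writes the new token's k for all heads of this
# layer into one buffer (SmartMask style) instead of into n_heads buffers.
pos = 7
new_k = torch.randn(b, n_heads, 1, head_dim)
k_cache[:, :, pos : pos + 1, :] = new_k
```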
I see, it seems like the diagram is also a bit different from the description here because we have separate KV caches in the diagram. Can you help update it too?
I have updated it in the latest commit.
Hey, can you help rebase? Trying to check the internal signal.
Done. Thanks :)
There are some internal errors, can you fix with the patch below?
diff --git a/executorch/backends/qualcomm/tests/TARGETS b/executorch/backends/qualcomm/tests/TARGETS
--- a/executorch/backends/qualcomm/tests/TARGETS
+++ b/executorch/backends/qualcomm/tests/TARGETS
@@ -58,6 +58,10 @@
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/backends/qualcomm/_passes:passes",
+ "//executorch/backends/qualcomm/partition:partition",
+ "//executorch/examples/models/llama:transformer_modules",
+ "//executorch/examples/qualcomm/oss_scripts/llama:masking_utils",
+ "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
"//executorch/backends/qualcomm/builders:builders",
],
)
diff --git a/executorch/backends/qualcomm/utils/targets.bzl b/executorch/backends/qualcomm/utils/targets.bzl
--- a/executorch/backends/qualcomm/utils/targets.bzl
+++ b/executorch/backends/qualcomm/utils/targets.bzl
@@ -5,7 +5,7 @@
The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""
-
+
runtime.python_library(
name = "utils",
srcs = glob([
@@ -17,6 +17,7 @@
deps = [
"//executorch/exir/backend:backend_details",
"//executorch/exir/backend:compile_spec_schema",
+ # "//executorch/backends/qualcomm/partition:partition",
"//executorch/backends/qualcomm/serialization:serialization",
],
)
diff --git a/executorch/examples/qualcomm/oss_scripts/llama/TARGETS b/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
--- a/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
+++ b/executorch/examples/qualcomm/oss_scripts/llama/TARGETS
@@ -8,10 +8,16 @@
runtime.python_library(
name = "static_llama",
srcs = [
+ "model/__init__.py",
+ "model/apply_rope.py",
+ "model/feed_forward.py",
+ "model/layernorm.py",
"model/static_llama.py",
],
deps = [
"//caffe2:torch",
+ "//executorch/examples/models/llama:transformer_modules",
+ "fbsource//third-party/pypi/transformers:transformers",
],
)
diff --git a/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py b/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
--- a/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
+++ b/executorch/examples/qualcomm/oss_scripts/llama/model/__init__.py
@@ -10,7 +10,7 @@
__all__ = [
- FeedForward_REGISTRY,
- ROTARY_EMB_REGISTRY,
- NORM_REGISTRY,
+ "FeedForward_REGISTRY",
+ "ROTARY_EMB_REGISTRY",
+ "NORM_REGISTRY",
]
Fixed