executorch
[MPS - DRAFT] Add support for slice_scatter; enable index_put
Summary of changes:
- Add support for slice_scatter
- Enable index_put
With whole-model delegation, I am seeing the following crash in llama2:

```
  in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Buffer output getitem_1 does not point to a buffer that exists.
Dict of buffers that are mutated, in order: {'getitem_1': 'layers_0_attention_SDPA_kv_cache_k_cache', 'getitem': 'layers_0_attention_SDPA_kv_cache_v_cache', 'getitem_3': 'layers_1_attention_SDPA_kv_cache_k_cache', 'getitem_2': 'layers_1_attention_SDPA_kv_cache_v_cache', 'getitem_5': 'layers_2_attention_SDPA_kv_cache_k_cache', 'getitem_4': 'layers_2_attention_SDPA_kv_cache_v_cache', 'getitem_7': 'layers_3_attention_SDPA_kv_cache_k_cache', 'getitem_6': 'layers_3_attention_SDPA_kv_cache_v_cache', 'getitem_9': 'layers_4_attention_SDPA_kv_cache_k_cache', 'getitem_8': 'layers_4_attention_SDPA_kv_cache_v_cache'}
Buffer nodes available: []
```
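The check behind this error can be sketched as follows. This is a simplified illustration of what `_verify_exported_program_signature` is complaining about, not the actual `torch._export.verifier` code: every output that the signature records as mutating a buffer must still resolve to a buffer node in the graph, and here the delegated KV caches left none behind.

```python
# Simplified sketch of the verifier check; not the real torch._export code.

class SpecViolationError(Exception):
    pass

def verify_buffer_outputs(buffers_to_mutate, buffer_nodes_available):
    """Each output listed in buffers_to_mutate must point to a buffer
    node that still exists in the graph. If delegation swallowed the
    buffer nodes (as with the KV caches above), the check fails on the
    first mutating output."""
    available = set(buffer_nodes_available)
    for output_name, buffer_name in buffers_to_mutate.items():
        if buffer_name not in available:
            raise SpecViolationError(
                f"Buffer output {output_name} does not point to a buffer that exists."
            )
```

With `Buffer nodes available: []`, the very first entry (`getitem_1`) trips the check, which matches the crash above.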
Commands to lower llama2 to MPS:
- python -m examples.models.llama2.export_llama -kv --mps
- python3 -m examples.apple.mps.scripts.mps_example --model_name="llama2"
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3399
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:x: 2 New Failures
As of commit ae4940c5a168a770363d3591d038f282dbe19d6f with merge base 87d828aabe88fd40e33fc82e03a063aec6f0b4fc:
NEW FAILURES - The following jobs have failed:
- Apple / test-demo-ios / macos-job (gh)
  RuntimeError: Command bash /Users/runner/work/_temp/exec_script failed with exit code 65
- Apple / upload-frameworks-ios (gh)
  Credentials could not be loaded, please check your action inputs: Could not load credentials from any providers
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I checked out this PR and ran

```
git submodule sync
git submodule update --init
./backends/apple/mps/install_requirements.sh
python -m examples.models.llama2.export_llama -kv --mps
```

but still can't repro...
@cccclai could you please run `./install_requirements.sh` or `pip install . --no-build-isolation -v` after checking out the branch? It seems it's still tracing with the old code.
I see the Metal kernel compilation path is not enabled. Is there a reason the indexing ops require Metal kernels, and is there a plan to enable the Metal kernel path?
I'm asking because I'm thinking about hooking up an int4 mm kernel using the Metal kernel flow. I have the Metal kernel ready and am trying to figure out how to inject it into the graph builder and so on.
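One way to think about injecting a custom kernel is at the partitioner's operator-support check: the custom op gets declared supported so it lands in the delegated partition, and the backend's serializer then maps it to the hand-written kernel. The sketch below is purely illustrative; the class shape loosely mirrors `MPSOperatorSupport` from this PR, but the op name `custom::int4_mm` and the wiring are assumptions, not the actual ExecuTorch API.

```python
# Hypothetical sketch: extending an operator-support check to accept a
# custom op (e.g. an int4 matmul backed by a hand-written Metal kernel).
# Names here (custom::int4_mm, the set-based lookup) are illustrative
# assumptions, not the real MPS backend interface.

SUPPORTED_CUSTOM_OPS = {"custom::int4_mm"}

class CustomKernelSupport:
    def __init__(self, base_supported_ops):
        # Ops the backend already handles via its normal lowering path.
        self.base = set(base_supported_ops)

    def is_node_supported(self, node_op, node_target):
        # Mirrors the is_node_supported shape: only call_function nodes
        # are candidates for delegation.
        if node_op != "call_function":
            return False
        # Accept both built-in ops and the injected custom op, so the
        # partitioner tags the custom node into the delegated subgraph.
        return node_target in self.base or node_target in SUPPORTED_CUSTOM_OPS
```

The serializer/graph-builder side would then need a matching case that emits the Metal kernel invocation for the tagged node; that part is backend-specific and not shown here.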
I have it working with this patch:
```diff
diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py
index e5497389d..8e22169c0 100644
--- a/backends/apple/mps/partition/mps_partitioner.py
+++ b/backends/apple/mps/partition/mps_partitioner.py
@@ -43,12 +43,6 @@ class MPSOperatorSupport(OperatorSupportBase):
         self.edge_program = edge_program

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        # Parameters are supported if any of their users are supported
-        if is_parameter(self.edge_program, node):
-            return any(
-                self.is_node_supported(submodules, user) for user in node.users.keys()
-            )
-
         if node.op != "call_function":
             return False
```
The root cause is that we're tagging the mutable buffers.
If MPS doesn't support buffer mutation, this line is sufficient to tag the constants, and it will exclude the mutable buffers:
`tag_constant_data(edge_program)`
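The distinction being drawn can be sketched as follows. This is a minimal illustration of why tagging only constants avoids the crash, with invented helper and field names (`buffers_to_mutate`, `buffer_name`); it is not the real `tag_constant_data` implementation.

```python
# Illustrative sketch, not the real tag_constant_data: tag parameters and
# non-mutable buffers for delegation, but leave mutable buffers (e.g. the
# KV caches) to the runtime. Tagging a mutable buffer into the delegate is
# what made the exported program's signature point at buffers that no
# longer exist, triggering the SpecViolationError above.

def tag_constant_data_sketch(graph_signature, nodes):
    # Buffers the program mutates; these must stay visible to the runtime.
    mutated = set(graph_signature.get("buffers_to_mutate", {}).values())
    tagged = []
    for node in nodes:
        buffer = node.get("buffer_name")
        if buffer is None:
            continue  # not a constant/buffer input
        if buffer in mutated:
            continue  # mutable buffer: do NOT delegate it
        tagged.append(node["name"])
    return tagged
```

Under this scheme a KV-cache buffer is skipped while an ordinary weight is tagged, so the delegate only consumes true constants.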
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@cccclai merged this pull request in pytorch/executorch@ea9647f470cf2cd5bda2b034cbf9ae9896f37039.