
[MPS - DRAFT] Add support for slice_scatter; enable index_put

DenisVieriu97 opened this pull request · 5 comments

Summary of changes:

  • support for slice_scatter
  • enable index_put

With whole-model delegation, I am seeing the following crash in llama2:

in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Buffer output getitem_1 does not point to a buffer that exists.
Dict of buffers that are mutated, in order: {'getitem_1': 'layers_0_attention_SDPA_kv_cache_k_cache', 'getitem': 'layers_0_attention_SDPA_kv_cache_v_cache', 'getitem_3': 'layers_1_attention_SDPA_kv_cache_k_cache', 'getitem_2': 'layers_1_attention_SDPA_kv_cache_v_cache', 'getitem_5': 'layers_2_attention_SDPA_kv_cache_k_cache', 'getitem_4': 'layers_2_attention_SDPA_kv_cache_v_cache', 'getitem_7': 'layers_3_attention_SDPA_kv_cache_k_cache', 'getitem_6': 'layers_3_attention_SDPA_kv_cache_v_cache', 'getitem_9': 'layers_4_attention_SDPA_kv_cache_k_cache', 'getitem_8': 'layers_4_attention_SDPA_kv_cache_v_cache'}
Buffer nodes available: []

Commands to lower llama2 to MPS:

  • python -m examples.models.llama2.export_llama -kv --mps
  • python3 -m examples.apple.mps.scripts.mps_example --model_name="llama2"

DenisVieriu97 · Apr 29 '24

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3399

Note: Links to docs will display an error until the docs builds have been completed.

:x: 2 New Failures

As of commit ae4940c5a168a770363d3591d038f282dbe19d6f with merge base 87d828aabe88fd40e33fc82e03a063aec6f0b4fc:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Apr 29 '24

I checked out this PR and ran

git submodule sync
git submodule update --init
./backends/apple/mps/install_requirements.sh
python -m examples.models.llama2.export_llama -kv --mps

but still can't repro...

cccclai · May 01 '24

> I checked out this PR and ran
>
> git submodule sync
> git submodule update --init
> ./backends/apple/mps/install_requirements.sh
> python -m examples.models.llama2.export_llama -kv --mps
>
> but still can't repro...

@cccclai could you please run ./install_requirements.sh or pip install . --no-build-isolation -v after checking out the branch? It seems it's still tracing with the old code.
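
For reference, the full reinstall-and-export sequence after checking out the branch (commands taken from this thread, run from the executorch repo root) would be:

./install_requirements.sh    # or: pip install . --no-build-isolation -v
python -m examples.models.llama2.export_llama -kv --mps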

DenisVieriu97 · May 01 '24

I see the metal kernel compilation path is not enabled. Is there a reason the indexing ops require metal kernels, and is there a plan to enable the metal kernel path?

Asking because I'm thinking about hooking up an int4 mm kernel using the metal kernel flow. I have the metal kernel ready and am trying to figure out how to inject it into the graph builder and so on.

larryliu0820 · May 02 '24

Have it working with this patch:

diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py
index e5497389d..8e22169c0 100644
--- a/backends/apple/mps/partition/mps_partitioner.py
+++ b/backends/apple/mps/partition/mps_partitioner.py
@@ -43,12 +43,6 @@ class MPSOperatorSupport(OperatorSupportBase):
         self.edge_program = edge_program

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        # Parameters are supported if any of their users are supported
-        if is_parameter(self.edge_program, node):
-            return any(
-                self.is_node_supported(submodules, user) for user in node.users.keys()
-            )
-
         if node.op != "call_function":
             return False

The root cause is that we're tagging the mutable buffers.

If MPS doesn't support buffer mutation, this line is enough to tag the constants, and it will exclude the mutable buffers:

tag_constant_data(edge_program)
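
For context, a rough sketch of where that call sits in a typical ExecuTorch partitioner's partition() method is below. Only tag_constant_data and PartitionResult are real ExecuTorch APIs; the surrounding structure and attribute names (supported_ops, delegation_spec) are illustrative and may differ from the actual mps_partitioner.py:

# Hypothetical sketch; not the literal contents of mps_partitioner.py.
from executorch.exir.backend.partitioner import PartitionResult
from executorch.exir.backend.utils import tag_constant_data

def partition(self, exported_program) -> PartitionResult:
    partition_tags = {}
    for node in exported_program.graph_module.graph.nodes:
        # Tag only supported call_function nodes; parameters and mutable
        # buffers (e.g. the llama2 KV cache) are never tagged directly.
        if node.op == "call_function" and self.supported_ops.is_node_supported(None, node):
            tag = "mps_partition"
            node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

    # Tags constant data (weights) feeding already-tagged nodes and skips
    # mutable buffers, so they stay out of the delegated partition.
    tag_constant_data(exported_program)

    return PartitionResult(
        tagged_exported_program=exported_program, partition_tags=partition_tags
    )

Letting tag_constant_data drive the weight tagging, instead of recursing from parameters into their users as the removed code did, is what keeps the mutated KV-cache buffers out of the partition and avoids the SpecViolationError above.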

cccclai · May 03 '24

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot · May 13 '24

@cccclai merged this pull request in pytorch/executorch@ea9647f470cf2cd5bda2b034cbf9ae9896f37039.

facebook-github-bot · May 13 '24