feat:[AutoDeploy] utilize torch._inductor.pattern_matcher to write pattern matcher
Description
- Add an example of using torch._inductor.pattern_matcher to match the RoPE pattern with explicit cos/sin
- Wrap the pattern matcher helpers in a utility file
Test Coverage
@Fridah-nv, I have found one interesting corner case that throws an error at the moment:
When we have a pattern-matched node from a previous pattern matcher and then have another pattern matcher that uses that pattern-matched node as input, an error is thrown because the FakeMode is not the same instance. You might be able to reproduce it with the double RoPE pattern matching I proposed for the DeepSeek RoPE.
I have noticed it for match_repeat_kv + match_eager_attention. I can provide a more detailed example to reproduce it later if you want. Just wanted to let you know in the meantime.
> when we have a pattern matched node from a previous pattern matcher and then have another pattern matcher that uses that pattern matched node as input, there is an error that gets thrown because the FakeMode is not the same instance.
I don't fully understand this. Which FakeMode are you referring to?
Can this error be avoided by using two instances of PatternMatcherPass?
def _interleaved_rope_pattern2(q, k, cos, sin, unsqueeze_dim=1):
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    return torch.ops.rope.torch_apply_rope_with_explicit_cos_sin.default(
        q, k, cos, sin, unsqueeze_dim
    )
This does not work for me, as it generates two additional getitem nodes in the exported search graph:
def forward(self, args_0, args_1, args_2, args_3):
    args_0, args_1, args_2, args_3, = fx_pytree.tree_flatten_spec(([args_0, args_1, args_2, args_3], {}), self._in_spec)
    view = torch.ops.aten.view.default(args_0, [8, 8, 16, 32, 2]); args_0 = None
    view_1 = torch.ops.aten.view.default(args_1, [8, 8, 16, 32, 2]); args_1 = None
    transpose = torch.ops.aten.transpose.int(view, 4, 3); view = None
    transpose_1 = torch.ops.aten.transpose.int(view_1, 4, 3); view_1 = None
    reshape = torch.ops.aten.reshape.default(transpose, [8, 8, 16, 64]); transpose = None
    reshape_1 = torch.ops.aten.reshape.default(transpose_1, [8, 8, 16, 64]); transpose_1 = None
    torch_apply_rope_with_explicit_cos_sin = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin.default(reshape, reshape_1, args_2, args_3); reshape = reshape_1 = args_2 = args_3 = None
    getitem = torch_apply_rope_with_explicit_cos_sin[0]
    getitem_1 = torch_apply_rope_with_explicit_cos_sin[1]; torch_apply_rope_with_explicit_cos_sin = None
    return pytree.tree_unflatten((getitem, getitem_1), self._out_spec)
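For readers following the thread, here is a minimal sketch of what the `view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)` step in the pattern above does to the last dimension. It uses plain Python lists in place of tensors, and the helper name `deinterleave_last_dim` is hypothetical (it is not part of this PR). The assumption is that each row stands in for one head-dimension vector of `q` or `k`.

```python
def deinterleave_last_dim(row):
    """Mimic view(d // 2, 2) -> transpose -> reshape(d) on a single vector.

    Hypothetical helper for illustration only; the PR applies the same
    rearrangement to the last dimension of 4-D tensors.
    """
    d = len(row)
    if d % 2 != 0:
        raise ValueError("last dimension must be even")
    # view(d // 2, 2): group consecutive elements into pairs
    pairs = [row[i:i + 2] for i in range(0, d, 2)]
    # transpose: shape (d // 2, 2) becomes (2, d // 2)
    halves = [list(col) for col in zip(*pairs)]
    # reshape(d): flatten the two halves back into one vector
    return [x for half in halves for x in half]


# Even-indexed entries end up in the first half, odd-indexed in the second.
example = deinterleave_last_dim([0, 1, 2, 3, 4, 5])
```

In other words, the element at index `2 * i + j` moves to index `j * (d // 2) + i`, which converts an interleaved layout into the split-half layout.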
merged in https://github.com/nv-auto-deploy/TensorRT-LLM/pull/7