[feat] Add dependency awareness to torch-trt partitioning

Open • mfeliz-cruise opened this pull request 3 years ago • 4 comments

Adds a heuristic to torch-trt partitioning's segmentation that avoids materializing segments until we hit a dependency of that segment. This can significantly reduce the number of segments/engines in cases where the linear traversal of TorchScript nodes would otherwise produce alternating torch and TRT segments which are not dependent on each other.
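As a rough illustration of the idea, here is a minimal Python sketch (hypothetical node attributes, not the actual C++ implementation in this change): nodes are visited in topological order, each target accumulates a pending segment, and a pending segment is only materialized when the traversal reaches a node of the other target that actually consumes one of its values.

```python
# Minimal sketch of dependency-aware segmentation (illustrative only).
# Each node is assumed to have a .target ("trt" or "torch") and .inputs
# (the nodes producing the values it consumes); `nodes` is assumed to be
# in topological order.

def segment(nodes):
    segments = []   # finalized (target, [nodes]) pairs, in emission order
    pending = {}    # target -> nodes accumulated but not yet materialized

    for node in nodes:
        other = "torch" if node.target == "trt" else "trt"
        # Materialize the other target's pending segment only if this node
        # depends on something that segment produces; otherwise keep
        # accumulating so independent torch/TRT nodes do not force a split.
        if other in pending and set(node.inputs) & set(pending[other]):
            segments.append((other, pending.pop(other)))
        pending.setdefault(node.target, []).append(node)

    # Whatever remains has no cross-target dependencies between the two
    # pending lists, so it can be emitted as the final segments.
    for target, seg_nodes in pending.items():
        segments.append((target, seg_nodes))
    return segments
```

With a purely linear traversal, every switch between a torch node and a TRT node starts a new segment; delaying materialization until a real dependency appears lets interleaved but independent nodes collapse into far fewer segments.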

Fixes # (issue)

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

  • New feature (non-breaking change which adds functionality)

  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

  • This change requires a documentation update

  • [ ] My code follows the style guidelines of this project (You can use the linters)

  • [ ] I have performed a self-review of my own code

  • [ ] I have commented my code, particularly in hard-to-understand areas and hacks

  • [ ] I have made corresponding changes to the documentation

  • [ ] I have added tests to verify my fix or my feature

  • [ ] New and existing unit tests pass locally with my changes

  • [ ] I have added the relevant labels to my PR so that relevant reviewers are notified

mfeliz-cruise Aug 23 '22 18:08

I think this can be handled more simply in the future by partitioning directly on the dependency graph. This would require updating the min_block_size logic, but would remove the need to merge segments after the initial partition.

mfeliz-cruise Aug 23 '22 20:08

@mfeliz-cruise we are currently working on a major restructuring of the partitioning phase to hopefully bring it closer to other design patterns in the project and make it easier to debug the state and develop new features (https://github.com/pytorch/TensorRT/pull/1263). Could you try rebasing this work on that branch and point the PR to merge into partitioning_ctx?

narendasan Aug 24 '22 20:08

> @mfeliz-cruise we are currently working on a major restructuring of the partitioning phase to hopefully bring it closer to other design patterns in the project and make it easier to debug the state and develop new features (#1263). Could you try rebasing this work on that branch and point the PR to merge into partitioning_ctx?

It looks like there are still some errors on that branch.

bowang007 Aug 24 '22 22:08

> @mfeliz-cruise we are currently working on a major restructuring of the partitioning phase to hopefully bring it closer to other design patterns in the project and make it easier to debug the state and develop new features (#1263). Could you try rebasing this work on that branch and point the PR to merge into partitioning_ctx?
>
> It looks like there are still some errors on that branch.

I'll hold off for now until it stabilizes.

mfeliz-cruise Aug 25 '22 21:08

> @mfeliz-cruise we are currently working on a major restructuring of the partitioning phase to hopefully bring it closer to other design patterns in the project and make it easier to debug the state and develop new features (#1263). Could you try rebasing this work on that branch and point the PR to merge into partitioning_ctx?
>
> It looks like there are still some errors on that branch.
>
> I'll hold off for now until it stabilizes.

I've rebased and this should be ready for review.

mfeliz-cruise Oct 06 '22 22:10

Hello @mfeliz-cruise, I went through this PR and reviewed the test cases. I ran them and understand the final graphs you are trying to achieve post-segmentation. However, I'm a little unclear on the code logic. Can you explain the heuristic and logic in a write-up (with some references to the modified test cases)? Since this is advanced segmentation/merging, it would be nice for this write-up to serve as a reference in the future.

peri044 Oct 11 '22 00:10

Sure @peri044, do you have a standard place where you put this kind of documentation, or should I just expand the PR description?

mfeliz-cruise Oct 11 '22 00:10

We keep documentation for contributors on the implementation of partitioning here: https://pytorch.org/TensorRT/contributors/partitioning.html#partitioning

narendasan Oct 11 '22 00:10

> We keep documentation for contributors on the implementation of partitioning here: https://pytorch.org/TensorRT/contributors/partitioning.html#partitioning

I've taken a first pass at documenting this in docsrc/contributors/partitioning.rst.

mfeliz-cruise Oct 11 '22 20:10

  • How does the merging of adjacent segments work? What are the criteria to merge?

peri044 Oct 12 '22 18:10

> • How does the merging of adjacent segments work? What are the criteria to merge?

I added some more about this in partitioning.rst. Let me know if you have more questions @peri044. https://github.com/pytorch/TensorRT/pull/1304/files#diff-a9f595cc75ff499ecbaedbe818d92ad9543b68b00243b8da7f250f72e7f12cdcR239
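As a rough sketch of what merging adjacent segments can look like once the dependency-aware pass has run (a hypothetical helper assuming each segment is a (target, nodes) pair; the authoritative description is the one in partitioning.rst): adjacent segments that target the same backend, and that the dependency check did not force apart, are fused into a single segment.

```python
# Illustrative only: fuse neighbouring segments that target the same backend.
# After the dependency-aware pass, adjacent same-target segments have no
# dependency forcing a boundary between them, so they can be combined to
# reduce the number of engines.

def merge_adjacent_segments(segments):
    merged = []
    for target, seg_nodes in segments:
        if merged and merged[-1][0] == target:
            merged[-1][1].extend(seg_nodes)           # same target: fuse
        else:
            merged.append((target, list(seg_nodes)))  # new segment boundary
    return merged
```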

mfeliz-cruise Oct 12 '22 21:10

Hello @mfeliz-cruise, I removed this snippet and ran the tests, and they passed. I'm wondering what the use case behind this snippet is.

From my understanding, this is written for in-place ops. For example, in the sample graph:

%2 = aten::cat(%1)
%2 = aten::append(%2, %3)
%4 = aten::relu

For n = aten::append, use.user would be %2 (the aten::cat output), which would not be isAfter(n), correct? Do you have an example in mind which uses this?

peri044 Nov 02 '22 18:11

> Hello @mfeliz-cruise, I removed this snippet and ran the tests, and they passed. I'm wondering what the use case behind this snippet is.
>
> From my understanding, this is written for in-place ops. For example, in the sample graph:
>
> %2 = aten::cat(%1)
> %2 = aten::append(%2, %3)
> %4 = aten::relu
>
> For n = aten::append, use.user would be %2 (the aten::cat output), which would not be isAfter(n), correct? Do you have an example in mind which uses this?

It would be a case like this: https://github.com/pytorch/TensorRT/issues/1018, where we have an op that modifies its input without producing the modified value:

 = aten::_set_item(%out_dict.1, %3, %x.1)
%z.1 : Tensor = aten::__getitem__(%out_dict.1, %3)

Here %out_dict.1 is modified by _set_item, and we should recognize that this makes aten::__getitem__ a dependent of the _set_item. If we only looked at the node outputs here, we would not identify this relationship.
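To make the point concrete, here is a small hypothetical sketch (assumed node attributes, not the torch-trt code): when collecting a node's dependents, values the node mutates in place have to be treated like its outputs, otherwise the __getitem__ above would not be seen as depending on the _set_item.

```python
# Illustrative only: dependency collection that also follows in-place writes.
# Nodes are assumed to expose .kind, .inputs and .outputs; `later_nodes` are
# the nodes that come after `node` in the graph.

MUTATING_OPS = {"aten::_set_item", "aten::append"}  # illustrative subset

def dependents(node, later_nodes):
    affected = set(node.outputs)
    if node.kind in MUTATING_OPS:
        # The mutated value (assumed to be the first input) behaves like an
        # output: any later node reading it depends on this node.
        affected.add(node.inputs[0])
    return [n for n in later_nodes if affected & set(n.inputs)]
```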

mfeliz-cruise Nov 02 '22 19:11