
static shapes propagation: add `torch-shape-refinement-pipeline` after canonicalization

Open Shukla-Gaurav opened this issue 2 years ago • 2 comments

We can simplify the IR via canonicalization/folding, and the results of those simplifications can be used to compute some dynamic shapes at compile time. For example:

%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int128 = torch.constant.int 128
%int9223372036854775807 = torch.constant.int 9223372036854775807
%185 = torch.vtensor.literal(opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
%186 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%195 = torch.prim.NumToTensor.Scalar %int128 : !torch.int -> !torch.vtensor<[],si64>
%196 = torch.aten.add.Tensor %195, %186, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
%197 = torch.aten.Int.Tensor %196 : !torch.vtensor<[],si64> -> !torch.int
%198 = torch.aten.slice.Tensor %185, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
%199 = torch.aten.slice.Tensor %198, %int1, %int0, %197, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>

In this IR, it's possible to compute the only dynamic dimension, which appears in the aten.slice.Tensor op. By adding proper canonicalizations to aten.add.Tensor and aten.Int.Tensor, we can simplify %197 to %int128 (there is an open patch for the first canonicalization: https://github.com/llvm/torch-mlir/pull/935). However, this simplification will not eliminate the dynamic dimension in the aten.slice.Tensor op unless we run the torch-shape-refinement-pipeline again.
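The folding chain can be illustrated with a toy model in plain Python (not torch-mlir APIs; `slice_extent` is a hypothetical helper approximating aten.slice.Tensor's result-extent rule):

```python
# Toy model of the folding chain above. %195/%196/%197 (NumToTensor of 128,
# add with dense<0> and alpha=1, then Int.Tensor) all fold to the constant 128,
# which makes the end index of the second slice static.

INT64_MAX = 9223372036854775807  # the "slice to the end" sentinel value

def slice_extent(size, start, end, step):
    """Result extent of one sliced dimension (hypothetical helper)."""
    end = min(end, size)  # clamp the INT64_MAX sentinel to the dimension size
    return max(0, -(-(end - start) // step))  # ceiling division

end_index = 128 + 1 * 0  # %196/%197: 128 + alpha * 0 folds to 128

shape = [1, 512]
shape[0] = slice_extent(shape[0], 0, INT64_MAX, 1)  # %198: full slice on dim 0
shape[1] = slice_extent(shape[1], 0, end_index, 1)  # %199: end is now constant
print(shape)  # [1, 128] -- the "?" dimension becomes static
```

With the constant folded, shape refinement can replace `!torch.vtensor<[1,?],si64>` with `!torch.vtensor<[1,128],si64>`.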

Whenever we canonicalize/fold an op, there is room for shape simplification in its users. How should we achieve that? Should we add the torch-shape-refinement-pipeline after canonicalization in the pass pipeline? @silvasean @ramiro050 @cathyzhyi
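One way to frame the question: canonicalization and shape refinement feed each other, so interleaving them until the IR stops changing reaches a fixed point. A minimal sketch of that idea (plain Python with a toy dict-based "IR", not torch-mlir's actual pass infrastructure):

```python
# Sketch of running a sequence of passes to a fixed point. The "IR" here is
# just a dict from SSA value names to either constants or symbolic shapes;
# the two toy passes mimic the fold and the shape refinement discussed above.

def run_to_fixpoint(ir, passes, max_iters=10):
    """Apply each pass in order, repeating until the IR stops changing."""
    for _ in range(max_iters):
        before = ir
        for p in passes:
            ir = p(ir)
        if ir == before:
            break
    return ir

def canonicalize(ir):
    ir = dict(ir)
    if ir.get("%197") == "add(128, 0)":
        ir["%197"] = 128  # fold the add/Int.Tensor chain to a constant
    return ir

def refine_shapes(ir):
    ir = dict(ir)
    if isinstance(ir.get("%197"), int) and ir.get("%199") == "[1, ?]":
        ir["%199"] = f"[1, {ir['%197']}]"  # slice end is now static
    return ir

ir = {"%197": "add(128, 0)", "%199": "[1, ?]"}
print(run_to_fixpoint(ir, [canonicalize, refine_shapes]))
# {'%197': 128, '%199': '[1, 128]'}
```

The trade-off raised below is exactly about avoiding this kind of repeated iteration: ideally a single, well-placed canonicalize pass before shape refinement suffices.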

Shukla-Gaurav avatar Jun 22 '22 19:06 Shukla-Gaurav

@sjarus FYI this is the last dynamic shape that is required for you to lower BERT to TOSA.

powderluv avatar Jun 22 '22 20:06 powderluv

We don't want to run the shape refinement pipeline an arbitrary number of times. Ideally it will sit in one good place in the pass pipeline.

One idea is to put the canonicalize pass here: https://github.com/llvm/torch-mlir/blob/a34dad2e077592deb497a9077fc3188b6e1154d5/lib/Dialect/Torch/Transforms/Passes.cpp#L179

The inliner, which runs a few lines above that, should already be running canonicalizations, so adding an explicit canonicalize pass should only be needed if SimplifyShapeCalculations is exposing this constant-folding opportunity.

silvasean avatar Jun 22 '22 22:06 silvasean

We added canonicalizations for aten.add.Tensor and aten.Int.Tensor, so this shouldn't be an issue anymore.

silvasean avatar Oct 07 '22 13:10 silvasean