static shapes propagation: add `torch-shape-refinement-pipeline` after canonicalization
We can simplify the IR with the help of canonicalization/folding, and the results of those simplifications can be used to compute some dynamic shapes at compile time. For example:
```mlir
%int128 = torch.constant.int 128
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int9223372036854775807 = torch.constant.int 9223372036854775807
%185 = torch.vtensor.literal(opaque<"elided_large_const", "0xDEADBEEF"> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
%186 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%195 = torch.prim.NumToTensor.Scalar %int128 : !torch.int -> !torch.vtensor<[],si64>
%196 = torch.aten.add.Tensor %195, %186, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
%197 = torch.aten.Int.Tensor %196 : !torch.vtensor<[],si64> -> !torch.int
%198 = torch.aten.slice.Tensor %185, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512],si64>
%199 = torch.aten.slice.Tensor %198, %int1, %int0, %197, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
```
In this IR, the only dynamic dimension, the one in the result of the second `aten.slice.Tensor` op, can be computed at compile time. By adding proper canonicalizations for `aten.add.Tensor` and `aten.Int.Tensor`, we can simplify %197 to %int128 (there is an open patch for the first canonicalization: https://github.com/llvm/torch-mlir/pull/935). However, this simplification alone will not get rid of the dynamic dimension in the `aten.slice.Tensor` op until we run the `torch-shape-refinement-pipeline` again.
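To make the goal concrete, here is a sketch (hand-written, not actual compiler output) of what the tail of the IR above could look like once those canonicalizations fold %197 to %int128 and shape refinement is re-run:

```mlir
// Hypothetical result: %195/%196/%197 fold away, the slice end becomes
// %int128, and shape refinement can then infer a static result type.
%199 = torch.aten.slice.Tensor %198, %int1, %int0, %int128, %int1 : !torch.vtensor<[1,512],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128],si64>
```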
Whenever we canonicalize/fold an op, there is room for shape simplification in its users. How should we achieve that? Should we add the `torch-shape-refinement-pipeline` after canonicalization in the pass pipeline?
@silvasean @ramiro050 @cathyzhyi
@sjarus FYI: this is the last dynamic shape that needs to be resolved for you to lower BERT to TOSA.
We don't want to run the shape refinement pipeline an arbitrary number of times. Ideally it will sit in one well-chosen place in the pass pipeline.
One idea is to put the canonicalize pass here: https://github.com/llvm/torch-mlir/blob/a34dad2e077592deb497a9077fc3188b6e1154d5/lib/Dialect/Torch/Transforms/Passes.cpp#L179
The inliner, which runs a few lines above that, should already be running canonicalizations, so the explicit `canonicalize` pass should only be needed if `SimplifyShapeCalculations` is exposing this constant-folding opportunity.
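To illustrate that placement, here is a rough C++ sketch of the shape-refinement pipeline with the proposed addition. The pass names, order, and nesting are assumptions reconstructed from this thread and upstream MLIR, not copied from Passes.cpp, so treat it as pseudo-code for the idea rather than a patch:

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" // createInlinerPass, createCanonicalizerPass

#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;

// Sketch of the shape-refinement pipeline region linked above (names assumed).
static void sketchShapeRefinementPipeline(OpPassManager &pm) {
  pm.addPass(Torch::createReifyShapeCalculationsPass());
  // The inliner already canonicalizes the functions it inlines.
  pm.addPass(createInlinerPass());
  pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
  // Proposed: an explicit canonicalization here, in case
  // SimplifyShapeCalculations exposed folds (e.g. %197 -> %int128) that the
  // earlier inliner-driven canonicalization could not see.
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
}
```

With the canonicalizer in this spot, the folded constants are in place before shape calculations are dropped, so the refined static shapes propagate to users like the `aten.slice.Tensor` op above without re-running the whole pipeline.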
We added canonicalizations for `aten.add.Tensor` and `aten.Int.Tensor`, so this shouldn't be an issue anymore.