iree
Canonicalizer dropping dynamic dims in tensor.pack op
Running the canonicalizer on this pack drops a dynamic dim on the result shape:
module {
  func.func @main(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
    %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
    return %pack : tensor<32x7x?x16x1xf32>
  }
}
After iree-opt --canonicalize pack.mlir:
module {
  func.func @main(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<32x7x32x16x1xf32>
    %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x32x16x1xf32>
    %cast = tensor.cast %pack : tensor<32x7x32x16x1xf32> to tensor<32x7x?x16x1xf32>
    return %cast : tensor<32x7x?x16x1xf32>
  }
}
I recently added the inference to canonicalization patterns (https://github.com/llvm/llvm-project/pull/80848). What is the bug?
%pack = tensor.pack %arg0 padding_value(%cst : f32)
outer_dims_perm = [1, 2, 0]
inner_dims_pos = [2, 0]
inner_tiles = [16, 1]
into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
The size of dim1 in the source shape is 32; it is not tiled. Given that outer_dims_perm = [1, 2, 0], we know the dynamic size in the result shape is 32.
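The second inner tile (size 1) is for dim0 because inner_dims_pos = [2, 0], and dim0 is dynamic. With outer_dims_perm = [1, 2, 0], dim0 ends up as the last outer dim of the result, so the `?` in the result shape comes from dim0, not from dim1, and cannot be inferred as 32.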
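To make the shape reasoning concrete, here is a minimal Python sketch of how the result shape of this pack can be worked out by hand. It is not the MLIR implementation; `packed_shape` and `DYN` are hypothetical names, and it assumes that result outer dim `i` takes source outer dim `outer_dims_perm[i]`, which is consistent with the types in this issue.

```python
DYN = None  # stands in for a dynamic '?' dimension


def packed_shape(src, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Hand-computed result shape of a pack-like op (illustrative only)."""
    outer = list(src)
    # Each tiled source dim is divided by its tile size, rounding up
    # because a padding value is provided. Dynamic dims stay dynamic.
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        if outer[pos] is not DYN:
            outer[pos] = -(-outer[pos] // tile)  # ceiling division
    # Assumption: result outer dim i is source outer dim outer_dims_perm[i].
    if outer_dims_perm is not None:
        outer = [outer[p] for p in outer_dims_perm]
    # The inner tile sizes are appended after the outer dims.
    return outer + list(inner_tiles)


# tensor<?x32x100xf32> with the attributes from the issue.
shape = packed_shape([DYN, 32, 100], [2, 0], [16, 1], [1, 2, 0])
print(shape)
```

Under these assumptions the third result dim stays dynamic, matching tensor<32x7x?x16x1xf32> from the original IR rather than the 32 the canonicalizer produced.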
This is fixed by https://github.com/llvm/llvm-project/pull/82539