Improve the vectorization of reverse-like tensor.extract ops
We observed that the vectorization of a reverse-like tensor.extract op produced incorrect code in https://github.com/openxla/iree/issues/16544.
Input:
func.func @foo_dispatch_0_generic_2x1x3_f32() {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x2x3xf32>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>>
  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 1, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>> -> tensor<2x1x3xf32>
  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 2, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x2x3xf32>> -> tensor<1x2x3xf32>
  %4 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg1 = %2) -> (tensor<2x1x3xf32>) {
    %extracted_slice = tensor.extract_slice %arg1[%arg0, 0, 0] [1, 1, 3] [1, 1, 1] : tensor<2x1x3xf32> to tensor<1x1x3xf32>
    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%extracted_slice : tensor<1x1x3xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [1, 1, 4], [0, 0, 0], [0, 0, 0]]>} {
    ^bb0(%out: f32):
      %6 = linalg.index 1 : index
      %7 = linalg.index 0 : index
      %8 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%6, %7, %arg0)
      %9 = linalg.index 2 : index
      %10 = arith.subi %c2, %9 : index
      %extracted = tensor.extract %3[%c0, %8, %10] : tensor<1x2x3xf32>
      linalg.yield %extracted : f32
    } -> tensor<1x1x3xf32>
    %inserted_slice = tensor.insert_slice %5 into %arg1[%arg0, 0, 0] [1, 1, 3] [1, 1, 1] : tensor<1x1x3xf32> into tensor<2x1x3xf32>
    scf.yield %inserted_slice : tensor<2x1x3xf32>
  }
  flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [2, 1, 3], strides = [1, 1, 1] : tensor<2x1x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>>
  return
}
Old output:
module {
  func.func @foo_dispatch_0_generic_2x1x3_f32() {
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant 0.000000e+00 : f32
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x2x3xf32>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>>
    %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 1, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>> -> tensor<2x1x3xf32>
    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 2, 3], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x2x3xf32>> -> tensor<1x2x3xf32>
    %4 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg1 = %2) -> (tensor<2x1x3xf32>) {
      %extracted_slice = tensor.extract_slice %arg1[%arg0, 0, 0] [1, 1, 3] [1, 1, 1] : tensor<2x1x3xf32> to tensor<1x1x3xf32>
      %5 = vector.constant_mask [1, 1, 3] : vector<1x1x4xi1>
      %6 = vector.broadcast %arg0 : index to vector<1x1x4xindex>
      %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
      %8 = vector.extractelement %7[%c0_i32 : i32] : vector<4xindex>
      %9 = vector.transfer_read %3[%c0, %8, %c2], %cst, %5 {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x4xf32>
      %10 = vector.transfer_write %9, %extracted_slice[%c0, %c0, %c0], %5 {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x3xf32>
      %inserted_slice = tensor.insert_slice %10 into %arg1[%arg0, 0, 0] [1, 1, 3] [1, 1, 1] : tensor<1x1x3xf32> into tensor<2x1x3xf32>
      scf.yield %inserted_slice : tensor<2x1x3xf32>
    }
    flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [2, 1, 3], strides = [1, 1, 1] : tensor<2x1x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x1x3xf32>>
    return
  }
}
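Looking at the original tensor.extract, the innermost index is %10 = %c2 - linalg.index 2, so the three lanes should read columns 2, 1, 0. The old output instead emits a single contiguous vector.transfer_read starting at column %c2, which reads columns 2, 3, 4. I provided a fix that makes this case fall back to the gather lowering. Lowering to vector.gather is always correct, but this access is really a contiguous load followed by a reverse at the vector level. One potential solution is to detect this pattern, load the whole contiguous slice, and reverse it at the vector level.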
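A minimal sketch of what the contiguous-load-plus-reverse lowering could look like for one row, written as a hypothetical standalone function (@reverse_row_sketch and its argument names are illustrative, not IREE code). It assumes the vector size equals the slice size (3), so no mask is needed, and it uses vector.shuffle to reverse the lanes; an actual vectorizer change may pick a different op or vector shape.

```mlir
// Hypothetical sketch: row %i of %src, reversed, written into %dst.
// Assumes vector size == slice size (3), so the read needs no mask.
func.func @reverse_row_sketch(%src: tensor<1x2x3xf32>, %dst: tensor<1x1x3xf32>, %i: index) -> tensor<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  // Contiguous read of columns 0..2, i.e. the full range covered by
  // 2 - linalg.index 2 for lanes 0..2 (assumes %i < 2, so in bounds).
  %row = vector.transfer_read %src[%c0, %i, %c0], %cst {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x3xf32>
  // Flatten to 1-D so vector.shuffle can permute the lanes.
  %flat = vector.shape_cast %row : vector<1x1x3xf32> to vector<3xf32>
  // Reverse: result lane j takes input lane 2 - j, giving columns 2, 1, 0.
  %rev = vector.shuffle %flat, %flat [2, 1, 0] : vector<3xf32>, vector<3xf32>
  %rev3d = vector.shape_cast %rev : vector<3xf32> to vector<1x1x3xf32>
  %res = vector.transfer_write %rev3d, %dst[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
  return %res : tensor<1x1x3xf32>
}
```

With the 4-lane vectors and the [1, 1, 3] mask from the old output, a plain reverse of a masked read starting at column 0 would move the padding lane to the front ([pad, a2, a1, a0] instead of [a2, a1, a0, pad]), so the masked case needs an adjusted lane mapping; the sketch sidesteps that by assuming the vector size matches the slice.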
Filing an issue so that we don't miss this case.