[spirv] vector.shape_cast is not handled in ConvertToSPIRV
The OptimizeVectorTransferPass used in the SPIR-V pipeline might generate vector.shape_cast ops during its optimizations. For example:
https://github.com/openxla/iree/blob/6c016cac1c94ddf72314ae85053142f4b9babf87/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp#L102-L104
But ConvertToSPIRVPass lowers vector ops by calling populateVectorToSPIRVPatterns, which doesn't handle vector.shape_cast.
We run into this issue when trying to drop unit dims for vector transfers in OptimizeVectorTransferPass (#13340), which creates vector.shape_cast ops to drop the unit dims on vectors.
Here is an example after OptimizeVectorTransferPass with the unit-dim dropping from #13340; the vector.shape_cast ops are generated during the optimization:
func.func @main_dispatch_71_generic_2x256_i8xi32() {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<[1196100044, 1139971180]> : tensor<2xi32>
%c0 = arith.constant 0 : index
%cst_0 = arith.constant dense<0> : vector<1xi32>
%c0_i8 = arith.constant 0 : i8
%cst_1 = arith.constant dense<-128> : vector<1xi32>
%cst_2 = arith.constant dense<39> : vector<1xi8>
%cst_3 = arith.constant dense<-1> : vector<1xi32>
%cst_4 = arith.constant dense<127> : vector<1xi32>
%cst_5 = arith.constant dense<-1.000000e+00> : vector<1xf32>
%cst_6 = arith.constant dense<0.0125187514> : vector<1xf32>
%c20224 = arith.constant 20224 : index
%c64 = arith.constant 64 : index
%cst_7 = arith.constant dense<[16267, -17079]> : tensor<2xi32>
%0 = bufferization.to_memref %cst_7 : memref<2xi32>
%1 = bufferization.to_memref %cst : memref<2xi32>
%2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c20224) flags(ReadOnly) : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
%3 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2xi32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %3, 64 : memref<2xi32, #hal.descriptor_type<storage_buffer>>
%4 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %4, 64 : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%subview = memref.subview %4[%workgroup_id_x] [1] [1] : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>> to memref<1xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_8 = memref.subview %0[%workgroup_id_x] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: ?>>
%subview_9 = memref.subview %3[%workgroup_id_x] [1] [1] : memref<2xi32, #hal.descriptor_type<storage_buffer>> to memref<1xi32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%subview_10 = memref.subview %2[0, %workgroup_id_x] [256, 1] [1, 1] : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>> to memref<256x1xi8, strided<[2, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%5 = vector.transfer_read %subview_10[%c0, %c0], %c0_i8 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<256x1xi8, strided<[2, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<256xi8>
%6 = arith.extsi %5 : vector<256xi8> to vector<256xi32>
%7 = vector.broadcast %6 : vector<256xi32> to vector<1x256xi32>
%8 = vector.multi_reduction <add>, %7, %cst_0 [1] : vector<1x256xi32> to vector<1xi32>
%subview_11 = memref.subview %1[%workgroup_id_x] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: ?>>
%subview_12 = memref.subview %subview_8[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>> to memref<i32, strided<[], offset: ?>>
%9 = vector.transfer_read %subview_12[], %c0_i32 : memref<i32, strided<[], offset: ?>>, vector<i32>
%10 = vector.shape_cast %9 : vector<i32> to vector<1xi32>
%subview_13 = memref.subview %subview_9[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<i32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
%11 = vector.transfer_read %subview_13[], %c0_i32 : memref<i32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<i32>
%12 = vector.shape_cast %11 : vector<i32> to vector<1xi32>
%subview_14 = memref.subview %subview_11[0] [1] [1] : memref<1xi32, strided<[1], offset: ?>> to memref<i32, strided<[], offset: ?>>
%13 = vector.transfer_read %subview_14[], %c0_i32 : memref<i32, strided<[], offset: ?>>, vector<i32>
%14 = vector.shape_cast %13 : vector<i32> to vector<1xi32>
%15 = arith.muli %8, %cst_1 : vector<1xi32>
%16 = arith.subi %12, %15 : vector<1xi32>
%17 = arith.addi %10, %16 : vector<1xi32>
%18 = "tosa.apply_scale"(%17, %14, %cst_2) <{double_round = true}> : (vector<1xi32>, vector<1xi32>, vector<1xi8>) -> vector<1xi32>
%19 = arith.addi %18, %cst_3 : vector<1xi32>
%20 = arith.cmpi slt, %19, %cst_1 : vector<1xi32>
%21 = arith.select %20, %cst_1, %19 : vector<1xi1>, vector<1xi32>
%22 = arith.cmpi sgt, %19, %cst_4 : vector<1xi32>
%23 = arith.select %22, %cst_4, %21 : vector<1xi1>, vector<1xi32>
%24 = arith.trunci %23 : vector<1xi32> to vector<1xi8>
%25 = arith.sitofp %24 : vector<1xi8> to vector<1xf32>
%26 = arith.subf %25, %cst_5 : vector<1xf32>
%27 = arith.mulf %26, %cst_6 : vector<1xf32>
%subview_15 = memref.subview %subview[0] [1] [1] : memref<1xf32, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
%28 = vector.shape_cast %27 : vector<1xf32> to vector<f32>
vector.transfer_write %28, %subview_15[] : vector<f32>, memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
return
}
Reproduce
To reproduce, I dumped the IR containing vector.shape_cast just before ConvertToSPIRV: https://gist.github.com/pzread/48352affa3aa5255d285f81091e1ece8
iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-convert-to-spirv))))' sample.mlir
As discussed on Discord, we just need to add a vector-to-SPIR-V pattern to handle such size-1 shape casts -- they are no-ops when translating to SPIR-V, given that both the source and result types convert to scalar values. Assigning to @kuhar to fix it.
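For reference, here is a minimal sketch of what such a conversion pattern could look like (illustrative only, assuming the standard OpConversionPattern machinery; FoldTrivialShapeCast is a made-up name and this is not necessarily the exact pattern that landed):

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: fold away size-1 vector.shape_cast ops during vector-to-SPIR-V
// conversion. Both vector<1xT> and vector<T> convert to a plain scalar T in
// SPIR-V, so the cast involves no data movement and the converted source can
// be used directly.
struct FoldTrivialShapeCast final
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only size-1 casts are trivial; anything larger needs real handling.
    if (op.getResultVectorType().getNumElements() != 1 ||
        op.getSourceVectorType().getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a size-1 shape cast");
    // After type conversion the source already is the scalar the result maps
    // to, so the cast reduces to a plain replacement.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

Such a pattern would then be registered next to the existing vector-to-SPIR-V patterns (e.g. with patterns.add<FoldTrivialShapeCast>(typeConverter, context)) so it runs as part of the same conversion.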
I'm actually more than happy to try to fix this :) Assigning it to myself as well.
Cool, that's great! Feel free to ping @kuhar or me if you have questions then! :)
It looks like there is a size regression with dropping unit dims on vector transfers + converting the trivial shape_cast to a no-op:
https://github.com/openxla/iree/pull/14220#issuecomment-1606270765
I'll investigate it further.
The problem is in VectorReductionToGPU: the patterns in mlir::vector::populatePropagateWarpVectorDistributionPatterns can't handle the vector.shape_cast, which results in bad warp distribution (the whole computation ends up guarded by the lane-0 scf.if instead of being distributed across the warp); see the sketch after the dump below:
----- After VectorReduceToGPU -----
func.func @main_dispatch_84_generic_2x256_i8xi32() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<0.0125187514> : vector<1xf32>
%cst_1 = arith.constant dense<-1.000000e+00> : vector<1xf32>
%cst_2 = arith.constant dense<127> : vector<1xi32>
%cst_3 = arith.constant dense<-1> : vector<1xi32>
%cst_4 = arith.constant dense<39> : vector<1xi8>
%cst_5 = arith.constant dense<-128> : vector<1xi32>
%cst_6 = arith.constant dense<[16267, -17079]> : tensor<2xi32>
%c20224 = arith.constant 20224 : index
%c0_i8 = arith.constant 0 : i8
%cst_7 = arith.constant dense<[1196100044, 1139971180]> : tensor<2xi32>
%c0_i32 = arith.constant 0 : i32
%c64 = arith.constant 64 : index
%0 = gpu.thread_id x
%1 = bufferization.to_memref %cst_6 : memref<2xi32>
%2 = bufferization.to_memref %cst_7 : memref<2xi32>
%3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c20224) flags(ReadOnly) : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %3, 64 : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>
%4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2xi32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %4, 64 : memref<2xi32, #hal.descriptor_type<storage_buffer>>
%5 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %5, 64 : memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%6 = arith.cmpi eq, %0, %c0 : index
%alloc = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
scf.if %6 {
%9 = vector.transfer_read %3[%c0, %workgroup_id_x], %c0_i8 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<256x2xi8, strided<[2, 1], offset: 20224>, #hal.descriptor_type<storage_buffer>>, vector<256xi8>
%10 = arith.extsi %9 : vector<256xi8> to vector<256xi32>
%11 = vector.reduction <add>, %10, %c0_i32 : vector<256xi32> into i32
%12 = vector.broadcast %11 : i32 to vector<1xi32>
%13 = vector.transfer_read %1[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
%14 = vector.shape_cast %13 : vector<i32> to vector<1xi32>
%15 = vector.transfer_read %4[%workgroup_id_x], %c0_i32 : memref<2xi32, #hal.descriptor_type<storage_buffer>>, vector<i32>
%16 = vector.shape_cast %15 : vector<i32> to vector<1xi32>
%17 = vector.transfer_read %2[%workgroup_id_x], %c0_i32 : memref<2xi32>, vector<i32>
%18 = vector.shape_cast %17 : vector<i32> to vector<1xi32>
%19 = arith.muli %12, %cst_5 : vector<1xi32>
%20 = arith.subi %16, %19 : vector<1xi32>
%21 = arith.addi %14, %20 : vector<1xi32>
%22 = "tosa.apply_scale"(%21, %18, %cst_4) <{double_round = true}> : (vector<1xi32>, vector<1xi32>, vector<1xi8>) -> vector<1xi32>
%23 = arith.addi %22, %cst_3 : vector<1xi32>
%24 = arith.cmpi slt, %23, %cst_5 : vector<1xi32>
%25 = arith.select %24, %cst_5, %23 : vector<1xi1>, vector<1xi32>
%26 = arith.cmpi sgt, %23, %cst_2 : vector<1xi32>
%27 = arith.select %26, %cst_2, %25 : vector<1xi1>, vector<1xi32>
%28 = arith.trunci %27 : vector<1xi32> to vector<1xi8>
%29 = arith.sitofp %28 : vector<1xi8> to vector<1xf32>
%30 = arith.subf %29, %cst_1 : vector<1xf32>
%31 = arith.mulf %30, %cst_0 : vector<1xf32>
%32 = vector.shape_cast %31 : vector<1xf32> to vector<f32>
vector.transfer_write %32, %alloc[] : vector<f32>, memref<f32, #gpu.address_space<workgroup>>
}
gpu.barrier
%7 = vector.transfer_read %alloc[], %cst : memref<f32, #gpu.address_space<workgroup>>, vector<f32>
%8 = arith.cmpi eq, %0, %c0 : index
scf.if %8 {
vector.transfer_write %7, %5[%workgroup_id_x] : vector<f32>, memref<2xf32, strided<[1], offset: 16>, #hal.descriptor_type<storage_buffer>>
}
return
}
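For context, here is a rough sketch of the kind of propagation pattern that could sink such a trivial shape cast out of the warp region so the remaining WarpOp* patterns can keep distributing its operand. This is not the D154870 patch itself; getWarpResult and moveRegionToNewWarpOpAndAppendReturns are stand-ins for the file-local helpers used by the upstream patterns in VectorDistribute.cpp, with assumed signatures:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: move a size-1 vector.shape_cast that feeds a
// vector.warp_execute_on_lane_0 result out of the warp region, yielding the
// cast's source instead and recreating the cast on the distributed value.
struct WarpOpTrivialShapeCastSketch
    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp result whose yielded value is produced by a size-1 cast
    // (assumed helper, analogous to the one in VectorDistribute.cpp).
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      auto cast = dyn_cast<vector::ShapeCastOp>(op);
      return cast && cast.getResultVectorType().getNumElements() == 1;
    });
    if (!yieldOperand)
      return failure();
    auto castOp = yieldOperand->get().getDefiningOp<vector::ShapeCastOp>();

    // Rebuild the warp op so it additionally yields the cast's source
    // (assumed helper with the same role as the upstream utility).
    SmallVector<Value> newYieldValues = {castOp.getSource()};
    SmallVector<Type> newRetTypes = {castOp.getSourceVectorType()};
    SmallVector<size_t> newRetIndices;
    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newYieldValues, newRetTypes, newRetIndices);

    // Recreate the (still trivial) cast outside the region and use it in
    // place of the old warp result.
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        castOp.getLoc(), castOp.getResultVectorType(),
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(
        newWarpOp->getResult(yieldOperand->getOperandNumber()), newCast);
    return success();
  }
};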
I'll send another patch to handle that.
The patch to fix the vector distribution has been sent out for review: https://reviews.llvm.org/D154870
Unassigning myself as I'm not working on this right now.
Closing this for now -- no actions planned