iree
Slow quantized matmul in MobileBERT on c2-standard-16
In single-threaded mode on c2-standard-16, IREE is roughly 2x slower than the baseline (526 ms vs. 278 ms). Several quantized GEMMs are particularly slow in this configuration.
dispatch_8_matmul_384x128x512
This is a fill + matmul case (i8 × i8 operands accumulated into i32).
// Accumulator: a 384x128 i32 tensor filled with %c0_i32
// (definition not shown in this excerpt; presumably the constant 0).
%2 = linalg.init_tensor [384, 128] : tensor<384x128xi32>
%3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<384x128xi32>) -> tensor<384x128xi32>
// i8 x i8 -> i32 matmul: (384x512) * (512x128), i.e. M=384, N=128, K=512.
%4 = linalg.matmul ins(%0, %1 : tensor<384x512xi8>, tensor<512x128xi8>) outs(%3 : tensor<384x128xi32>) -> tensor<384x128xi32>
dispatch_14_matmul_384x128x128
Also a fill + matmul case, with a different shape (K = 128 instead of 512).
// Accumulator: a 384x128 i32 tensor filled with %c0_i32
// (definition not shown in this excerpt; presumably the constant 0).
%2 = linalg.init_tensor [384, 128] : tensor<384x128xi32>
%3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<384x128xi32>) -> tensor<384x128xi32>
// i8 x i8 -> i32 matmul: (384x128) * (128x128), i.e. M=384, N=128, K=128.
%4 = linalg.matmul ins(%0, %1 : tensor<384x128xi8>, tensor<128x128xi8>) outs(%3 : tensor<384x128xi32>) -> tensor<384x128xi32>
dispatch_33_generic_4x384x384
This is an f32 sum reduction over the innermost dimension (4x384x384 → 4x384).
// Sum-reduces a 4x384x384 f32 tensor over its innermost dimension (d2),
// producing a 4x384 result.
// NOTE(review): #map0 is not referenced in this excerpt; it may be used by
// other ops in the full dispatch, so it is left as-is.
#map0 = affine_map<(d0, d1, d2) -> (d1, d2)>
// Input map: identity over the 3-D iteration space.
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Output map: drops the reduction dimension d2. This alias is referenced by
// the linalg.generic below but was missing from the excerpt; its form follows
// from the 4x384 output type and the ["parallel", "parallel", "reduction"]
// iterator types.
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// Accumulator initialized with %cst (definition not shown in this excerpt;
// presumably 0.0 for a sum).
%1 = linalg.init_tensor [4, 384] : tensor<4x384xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4x384xf32>) -> tensor<4x384xf32>
%3 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%0 : tensor<4x384x384xf32>) outs(%2 : tensor<4x384xf32>) {
^bb0(%arg2: f32, %arg3: f32):
// %arg2: input element; %arg3: running accumulator for (d0, d1).
%4 = arith.addf %arg2, %arg3 : f32
linalg.yield %4 : f32
} -> tensor<4x384xf32>
dispatch_6_matmul_384x512x384
Fill + quantized matmul + elementwise op (only the fill and matmul are shown below).
// Only the fill + matmul portion of this dispatch is shown; the elementwise
// part is not included in this excerpt.
// NOTE(review): %6 (384x512xi8) is not used below; presumably it is the
// output of the omitted elementwise op — confirm.
%6 = linalg.init_tensor [384, 512] : tensor<384x512xi8>
// Accumulator: a 384x512 i32 tensor filled with %c0_i32
// (definition not shown; presumably the constant 0).
%7 = linalg.init_tensor [384, 512] : tensor<384x512xi32>
%8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<384x512xi32>) -> tensor<384x512xi32>
// i8 x i8 -> i32 matmul: (384x384) * (384x512), i.e. M=384, N=512, K=384.
%9 = linalg.matmul ins(%0, %1 : tensor<384x384xi8>, tensor<384x512xi8>) outs(%8 : tensor<384x512xi32>) -> tensor<384x512xi32>
dispatch_34_generic_4x384x1x384
Quantized elementwise op (f32 → i8).
// Elementwise f32 -> i8 quantization over a 4x384x1x384 tensor.
// NOTE(review): %0, %1 and %cst, %cst_0 .. %cst_4 are defined outside this
// excerpt; their values are not visible here.
// Identity map over the 4-D iteration space (used for %0 and the output).
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Broadcast map: the 4x384 operand %1 is indexed by (d0, d1) only.
#map12 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
%2 = linalg.init_tensor [4, 384, 1, 384] : tensor<4x384x1x384xi8>
%3 = linalg.generic {indexing_maps = [#map11, #map12, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%0, %1 : tensor<4x384x1x384xf32>, tensor<4x384xf32>)
outs(%2 : tensor<4x384x1x384xi8>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: i8):
// %arg3: element of %0; %arg4: broadcast element of %1;
// %arg5: current output element (not read by this body).
// Scale and offset: %7 = %arg3 * (%cst / %arg4) * %cst_0 + %cst_1.
%4 = arith.divf %cst, %arg4 : f32
%5 = arith.mulf %arg3, %4 : f32
%6 = arith.mulf %5, %cst_0 : f32
%7 = arith.addf %6, %cst_1 : f32
// Subtract %cst_2 when %7 < %cst_3, otherwise add it. Presumably
// %cst_2 = 0.5 and %cst_3 = 0.0, i.e. round-half-away-from-zero ahead of
// the truncating fptosi below — TODO confirm constant values.
%8 = arith.addf %7, %cst_2 : f32
%9 = arith.subf %7, %cst_2 : f32
%10 = arith.cmpf olt, %7, %cst_3 : f32
%11 = arith.select %10, %9, %8 : f32
// Clamp: upper bound %cst_4, lower bound %cst_1.
// NOTE(review): %cst_1 is used both as the additive offset above and as the
// lower clamp bound here; confirm this is intended.
%12 = arith.minf %11, %cst_4 : f32
%13 = arith.maxf %12, %cst_1 : f32
// fptosi rounds toward zero when converting to i8.
%14 = arith.fptosi %13 : f32 to i8
linalg.yield %14 : i8
} -> tensor<4x384x1x384xi8>