[CPU] i4 pack op fails to compile
It looks like some memref.subview ops are not optimized away for this i4 pack op, and we then try to apply narrow type emulation to them:
#config = #iree_codegen.lowering_config<tile_sizes = [[20000, 16000], [1, 1]]>
#executable_target_system_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "system-elf-arm_64", {cpu = "", cpu_features = "+neon", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", link_embedded = false, native_vector_size = 16 : index, target_triple = "aarch64-none-linux-android34", ukernels = "none"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
#translation = #iree_codegen.translation_info<CPUDataTiling>
module {
hal.executable public @pack_i4 {
hal.executable.variant public @system_elf_arm_64 target(#executable_target_system_elf_arm_64_) {
hal.executable.export public @pack_i4 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @pack_i4() {
%c0_i4 = arith.constant 0 : i4
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16000, 32000], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16000x32000xi4>> -> tensor<16000x32000xi4>
%3 = tensor.empty() : tensor<200000x16000x64x1xi4>
%pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [200000, 16000, 64, 1], strides = [1, 1, 1, 1] : tensor<200000x16000x64x1xi4> -> !flow.dispatch.tensor<writeonly:tensor<200000x16000x64x1xi4>>
return
}
}
}
}
}
Error:
iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features="+neon" --iree-llvmcpu-target-triple=aarch64-none-linux-android34 --iree-opt-data-tiling=true --iree-llvmcpu-enable-ukernels=none --compile-from=executable-sources repro.mlir
repro.mlir:21:19: error: failed to legalize operation 'memref.subview' that was explicitly marked illegal
%pack = tensor.pack %2 padding_value(%c0_i4 : i4) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [64, 1] into %3 {lowering_config = #config} : tensor<16000x32000xi4> -> tensor<200000x16000x64x1xi4>
^
repro.mlir:21:19: note: see current operation: %29 = "memref.subview"(%7, %27, %28, %24) <{operandSegmentSizes = array<i32: 1, 2, 1, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<16000x32000xi4>, index, index, index) -> memref<1x?xi4, strided<[32000, 1], offset: ?>>
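Reading the generic form in the note: -9223372036854775808 is INT64_MIN, the value MLIR uses for ShapedType::kDynamic, so those offset and size entries are supplied by the SSA operands (%27 and %28 for the offsets, %24 for the size). The small Python helper below is only an illustration for decoding such dumps (decode_static is a made-up name); the result matches the pretty-printed memref.subview %0[%15, %16] [1, %12] [1, 1] in the IR dump.

```python
# MLIR prints dynamic entries of static_offsets / static_sizes / static_strides
# as ShapedType::kDynamic, which is INT64_MIN.
K_DYNAMIC = -(2**63)


def decode_static(values):
    """Replace the kDynamic sentinel with '?' (value supplied by an SSA operand)."""
    return ["?" if v == K_DYNAMIC else v for v in values]


# Attributes copied from the failing "memref.subview" in the error note.
static_offsets = [-9223372036854775808, -9223372036854775808]
static_sizes = [1, -9223372036854775808]
static_strides = [1, 1]
# operandSegmentSizes = [source, dynamic offsets, dynamic sizes, dynamic strides]
operand_segment_sizes = [1, 2, 1, 0]

decoded = [decode_static(a) for a in (static_offsets, static_sizes, static_strides)]
for name, d in zip(("offsets", "sizes", "strides"), decoded):
    print(f"{name}: {d}")

# Each '?' must be matched by exactly one dynamic operand.
dynamic_counts = [d.count("?") for d in decoded]
print(dynamic_counts == operand_segment_sizes[1:])  # True
```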
IR before the compilation error:
// -----// IR Dump After FoldMemRefAliasOps (fold-memref-alias-ops) //----- //
module {
func.func @pack_i4() {
%c1 = arith.constant 1 : index
%c20000 = arith.constant 20000 : index
%c16000 = arith.constant 16000 : index
%c200000 = arith.constant 200000 : index
%c0_i4 = arith.constant 0 : i4
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<16000x32000xi4>
memref.assume_alignment %0, 64 : memref<16000x32000xi4>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<200000x16000x64x1xi4>
memref.assume_alignment %1, 64 : memref<200000x16000x64x1xi4>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%2 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_id_y]
%3 = affine.apply affine_map<()[s0] -> (s0 * 20000)>()[%workgroup_count_y]
%4 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_id_x]
%5 = affine.apply affine_map<()[s0] -> (s0 * 16000)>()[%workgroup_count_x]
cf.br ^bb1(%2 : index)
^bb1(%6: index): // 2 preds: ^bb0, ^bb11
%7 = arith.cmpi slt, %6, %c200000 : index
cf.cond_br %7, ^bb2, ^bb12
^bb2: // pred: ^bb1
cf.br ^bb3(%4 : index)
^bb3(%8: index): // 2 preds: ^bb2, ^bb10
%9 = arith.cmpi slt, %8, %c16000 : index
cf.cond_br %9, ^bb4, ^bb11
^bb4: // pred: ^bb3
cf.br ^bb5(%c0 : index)
^bb5(%10: index): // 2 preds: ^bb4, ^bb9
%11 = arith.cmpi slt, %10, %c20000 : index
cf.cond_br %11, ^bb6, ^bb10
^bb6: // pred: ^bb5
%12 = affine.min affine_map<()[s0, s1] -> (s0 * -64 - s1 * 64 + 32000, 64)>()[%10, %6]
cf.br ^bb7(%c0 : index)
^bb7(%13: index): // 2 preds: ^bb6, ^bb8
%14 = arith.cmpi slt, %13, %c16000 : index
cf.cond_br %14, ^bb8, ^bb9
^bb8: // pred: ^bb7
%15 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
%16 = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1 * 64)>()[%6, %10]
%subview = memref.subview %0[%15, %16] [1, %12] [1, 1] : memref<16000x32000xi4> to memref<1x?xi4, strided<[32000, 1], offset: ?>>
%17 = vector.transfer_read %subview[%c0, %c0], %c0_i4 : memref<1x?xi4, strided<[32000, 1], offset: ?>>, vector<64xi4>
%18 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%6, %10]
%19 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%8, %13]
vector.store %17, %1[%18, %19, %c0, %c0] : memref<200000x16000x64x1xi4>, vector<64xi4>
%20 = arith.addi %13, %c1 : index
cf.br ^bb7(%20 : index)
^bb9: // pred: ^bb7
%21 = arith.addi %10, %c1 : index
cf.br ^bb5(%21 : index)
^bb10: // pred: ^bb5
%22 = arith.addi %8, %5 : index
cf.br ^bb3(%22 : index)
^bb11: // pred: ^bb3
%23 = arith.addi %6, %3 : index
cf.br ^bb1(%23 : index)
^bb12: // pred: ^bb1
return
}
}
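For context on where this fails: narrow type emulation rewrites i4 memrefs into byte-addressed storage, so every i4 element index has to be mapped to a byte index plus a bit offset. The Python model below sketches that mapping; it assumes two elements per byte with the even element in the low nibble, and pack_i4 / load_i4 are hypothetical helpers, not IREE or MLIR APIs. A memref.subview with dynamic offsets and sizes, like %subview above, sits between the i4 view and this byte-level addressing, which is presumably why the emulation expects such subviews to be folded into their users first.

```python
def pack_i4(values):
    """Pack signed 4-bit values two per byte (assumed: even element in the low nibble)."""
    if len(values) % 2:
        values = values + [0]  # pad to a whole number of bytes
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(packed)


def load_i4(buf, elem_offset, count):
    """Emulated load of `count` i4 elements starting at element `elem_offset`."""
    result = []
    for i in range(elem_offset, elem_offset + count):
        byte_index, slot = divmod(i, 2)  # two i4 elements per byte
        nibble = (buf[byte_index] >> (4 * slot)) & 0xF
        result.append(nibble - 16 if nibble >= 8 else nibble)  # sign-extend
    return result


values = [i % 16 - 8 for i in range(128)]  # every i4 value in [-8, 7]
buf = pack_i4(values)
print(len(buf))             # 64 bytes hold 128 i4 elements
print(load_i4(buf, 64, 8))  # [-8, -7, -6, -5, -4, -3, -2, -1]
```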
@hanhanW, @MaheshRavishankar this is the memref.subview error we talked about. Also pinging @Max191. This is currently blocking our DT+UK (data tiling + ukernels) enablement for i4.
Don't we have patterns to fold memref.subview into vector.transfer_read? They are here: https://github.com/shark-infra/llvm-project/blob/6a22c340976abaa0a65f580f8e83fd8ea1593b95/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp#L379-L380 That subview shouldn't be there after folding memref subviews.
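For reference, folding a memref.subview into a load-like user is index composition: an access at subview index i in dimension d becomes offset[d] + i * stride[d] on the source memref, where stride is the subview's own stride operand. Applied to the IR above, a read of %subview = memref.subview %0[%15, %16] [1, %12] [1, 1] at [%c0, %c0] becomes a read of %0 at [%15, %16]. The Python check below uses made-up values for %15 and %16 and hypothetical helper names; it only verifies that the composed indices address the same element as the subview's strided layout.

```python
def fold_subview_indices(offsets, subview_strides, indices):
    """Map indices on a memref.subview back to indices on its source memref."""
    return [o + i * s for o, i, s in zip(offsets, indices, subview_strides)]


def linearize(indices, layout_strides, layout_offset=0):
    """Element offset of `indices` in a strided memref layout."""
    return layout_offset + sum(i * s for i, s in zip(indices, layout_strides))


source_strides = [32000, 1]  # memref<16000x32000xi4> (identity layout)
row, col = 3, 128            # hypothetical runtime values of %15 and %16
offsets, subview_strides = [row, col], [1, 1]

# Result type memref<1x?xi4, strided<[32000, 1], offset: ?>>: the dynamic offset
# is where the subview starts inside the source buffer.
subview_offset = linearize(offsets, source_strides)

for idx in ([0, 0], [0, 5], [0, 63]):
    folded = fold_subview_indices(offsets, subview_strides, idx)
    via_subview = linearize(idx, source_strides, subview_offset)
    assert linearize(folded, source_strides) == via_subview
    print(idx, "->", folded)
```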
Is there a reason not to fold when in_bounds is false? https://github.com/shark-infra/llvm-project/blob/6a22c340976abaa0a65f580f8e83fd8ea1593b95/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp#L351
I'm trying to understand what would be needed to support this case...
I guess not; the pattern is probably just being too conservative. We can set in_bounds to false on the resulting op, and that should still preserve semantics.
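Why in_bounds matters for this fold: with in_bounds = false, the folded read checks bounds against the source memref instead of against the subview, so in general the two reads agree only if the subview is not shorter than the part of the source dimension the vector can reach. In the IR above, the subview size %12 is affine.min(32000 - offset, 64), so the subview either covers the whole 64-element vector or ends exactly at the end of the source row, and the fold is safe. The 1-D Python model below (a simplification of vector.transfer_read; all names are hypothetical) checks this tile by tile and also shows an interior subview where the two reads differ.

```python
PAD = 0  # padding value, %c0_i4 in the IR


def read_through_subview(src_row, offset, size, lanes, pad=PAD):
    """transfer_read at index 0 of subview [offset][size], in_bounds = false:
    lanes at or beyond `size` are out of bounds of the subview and get `pad`."""
    return [src_row[offset + i] if i < size else pad for i in range(lanes)]


def read_folded(src_row, offset, lanes, pad=PAD):
    """The same read folded onto the source row, in_bounds = false:
    only lanes past the end of the source row get `pad`."""
    return [src_row[offset + i] if offset + i < len(src_row) else pad
            for i in range(lanes)]


def tiles_agree(row_len, tile=64):
    """Compare both reads for every tile, using the affine.min size from the IR."""
    src_row = list(range(1, row_len + 1))  # non-zero data, so padding is visible
    for k in range(-(-row_len // tile)):   # ceil(row_len / tile) tiles
        offset = tile * k
        size = min(row_len - offset, tile)  # like %12 = affine.min(..., 64)
        through = read_through_subview(src_row, offset, size, tile)
        if through != read_folded(src_row, offset, tile):
            return False
    return True


print(tiles_agree(32000))  # True: row length from the IR, a multiple of 64
print(tiles_agree(100))    # True: last tile is partial (36 elements)

# Counterexample: an interior subview of size 10 is padded by the original read,
# while the folded read sees real data past lane 10.
row = list(range(1, 101))
print(read_through_subview(row, 0, 10, 64) == read_folded(row, 0, 64))  # False
```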
Thanks! I gave that a try: https://github.com/llvm/llvm-project/pull/80517 It fixes this issue.
I'm hitting a similar error in the narrow type emulation pass, this time involving a memref.subview + memref.collapse_shape + vector.load chain (no out-of-bounds access this time) that I don't think we can or should fold into the vector.load:
%subview_4 = memref.subview %7[%15, 0, 0, 0] [4, 512, 8, 8] [1, 1, 1, 1] : memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>> to memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
%collapse_shape_5 = memref.collapse_shape %subview_4 [[0], [1], [2, 3]] : memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>> into memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>
%41 = vector.load %collapse_shape_5[%19, %21, %c0] : memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>, vector<64xi4>
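Even without folding, the address arithmetic of this chain is simple. The subview only shifts the base by %15 * 32768 elements, and the collapse_shape merges two contiguous dimensions ([8, 8] with strides [8, 1] become [64] with stride 1), so vector.load %collapse_shape_5[%19, %21, %c0] starts at element base + (%15 + %19) * 32768 + %21 * 64 of %7 and reads 64 i4 values (32 bytes). The Python check below uses a hypothetical value for the dynamic base offset and samples of the loop indices; with an even base, the load always starts on a byte boundary.

```python
def linearize(indices, strides, offset=0):
    """Element offset of `indices` in a strided memref layout."""
    return offset + sum(i * s for i, s in zip(indices, strides))


SRC_STRIDES = [32768, 64, 8, 1]  # memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>


def offset_via_chain(base, row, i, j, k):
    """Element offset of %collapse_shape_5[i, j, k], built from %subview_4[row, 0, 0, 0]."""
    subview_base = linearize([row, 0, 0, 0], SRC_STRIDES, base)
    # collapse group [2, 3] of sizes [8, 8]: collapsed index k -> (k // 8, k % 8)
    return linearize([i, j, k // 8, k % 8], SRC_STRIDES, subview_base)


def offset_direct(base, row, i, j, k):
    """Same element addressed directly on %7 with the collapsed strides [32768, 64, 1]."""
    return linearize([row + i, j, k], [32768, 64, 1], base)


base = 1024  # hypothetical (even) value of the dynamic `offset: ?`
for row in (0, 4, 252):        # %15 advances in steps of 4
    for i in range(4):         # %19 < 4
        for j in (0, 1, 511):  # %21 < 512
            assert all(offset_via_chain(base, row, i, j, k) == offset_direct(base, row, i, j, k)
                       for k in range(64))
            start = offset_direct(base, row, i, j, 0)
            assert start % 2 == 0  # two i4 per byte: the load starts on a byte boundary
print("subview + collapse_shape + vector.load reduce to one linear offset")
```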
I haven't thought too much about this yet, but it looks like, in some cases, we may want to add support for memref.subview and memref.collapse_shape to NarrowTypeEmulation regardless of whether they are fused or not. I would appreciate some brainstorming, @Max191, @hanhanW, @MaheshRavishankar.
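One way to frame that brainstorming: a strided i4 layout can be rewritten as a byte-addressed layout only if every row it describes starts on a byte boundary. A rough, hypothetical legality check (byte_layout_for_i4 is not an existing MLIR or IREE function) could require a unit innermost stride plus an offset and outer strides that are multiples of the number of i4 elements per byte. Both layouts in this thread, strided<[32768, 64, 8, 1]> and the collapsed strided<[32768, 64, 1]>, pass it as long as the dynamic offset is even.

```python
def byte_layout_for_i4(offset, strides, elems_per_byte=2):
    """Hypothetical check: can a strided i4 layout (offset and strides in elements)
    be addressed in whole bytes? Returns (byte_offset, outer_byte_strides) or None."""
    *outer, innermost = strides
    if innermost != 1:
        return None  # packed elements must be contiguous in the innermost dim
    if offset % elems_per_byte or any(s % elems_per_byte for s in outer):
        return None  # some row would start in the middle of a byte
    return offset // elems_per_byte, [s // elems_per_byte for s in outer]


print(byte_layout_for_i4(4 * 32768, [32768, 64, 8, 1]))  # (65536, [16384, 32, 4])
print(byte_layout_for_i4(0, [32768, 64, 1]))             # (0, [16384, 32])
print(byte_layout_for_i4(3, [32768, 64, 1]))             # None: odd offset
print(byte_layout_for_i4(0, [32001, 1]))                 # None: odd row stride
```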
Full IR before NarrowTypeEmulation:
module {
func.func @main_dispatch_65_mmt4d_64x256x512x8x8x8_i8xi4xi32() {
%c7 = arith.constant 7 : index
%c6 = arith.constant 6 : index
%c5 = arith.constant 5 : index
%c3 = arith.constant 3 : index
%cst = arith.constant dense<0> : vector<64xi32>
%cst_0 = arith.constant dense<0> : vector<8x8xi4>
%cst_1 = arith.constant dense<0> : vector<8x8xi8>
%cst_2 = arith.constant dense<0> : vector<8x8xi32>
%cst_3 = arith.constant dense<0> : vector<1x1x8x8xi32>
%c512 = arith.constant 512 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c64 = arith.constant 64 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_castui %0 : i32 to index
%4 = arith.index_castui %1 : i32 to index
%5 = arith.index_castui %2 : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%3) flags(ReadOnly) : memref<64x512x8x8xi8, strided<[32768, 64, 8, 1], offset: ?>>
memref.assume_alignment %6, 1 : memref<64x512x8x8xi8, strided<[32768, 64, 8, 1], offset: ?>>
%7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%4) flags(ReadOnly) : memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
memref.assume_alignment %7, 1 : memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
%8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%5) : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>
memref.assume_alignment %8, 1 : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%9 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_y]
%10 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_y]
%11 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%12 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_count_x]
cf.br ^bb1(%9 : index)
^bb1(%13: index): // 2 preds: ^bb0, ^bb12
%14 = arith.cmpi slt, %13, %c64 : index
cf.cond_br %14, ^bb2, ^bb13
^bb2: // pred: ^bb1
%subview = memref.subview %6[%13, 0, 0, 0] [2, 512, 8, 8] [1, 1, 1, 1] : memref<64x512x8x8xi8, strided<[32768, 64, 8, 1], offset: ?>> to memref<2x512x8x8xi8, strided<[32768, 64, 8, 1], offset: ?>>
cf.br ^bb3(%11 : index)
^bb3(%15: index): // 2 preds: ^bb2, ^bb11
%16 = arith.cmpi slt, %15, %c256 : index
cf.cond_br %16, ^bb4, ^bb12
^bb4: // pred: ^bb3
%subview_4 = memref.subview %7[%15, 0, 0, 0] [4, 512, 8, 8] [1, 1, 1, 1] : memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>> to memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
cf.br ^bb5(%c0 : index)
^bb5(%17: index): // 2 preds: ^bb4, ^bb10
%18 = arith.cmpi slt, %17, %c2 : index
cf.cond_br %18, ^bb6(%c0 : index), ^bb11
^bb6(%19: index): // 2 preds: ^bb5, ^bb9
%20 = arith.cmpi slt, %19, %c4 : index
cf.cond_br %20, ^bb7(%c0, %cst_3 : index, vector<1x1x8x8xi32>), ^bb10
^bb7(%21: index, %22: vector<1x1x8x8xi32>): // 2 preds: ^bb6, ^bb8
%23 = arith.cmpi slt, %21, %c512 : index
cf.cond_br %23, ^bb8, ^bb9
^bb8: // pred: ^bb7
%collapse_shape = memref.collapse_shape %subview [[0], [1], [2, 3]] : memref<2x512x8x8xi8, strided<[32768, 64, 8, 1], offset: ?>> into memref<2x512x64xi8, strided<[32768, 64, 1], offset: ?>>
%24 = vector.load %collapse_shape[%17, %21, %c0] : memref<2x512x64xi8, strided<[32768, 64, 1], offset: ?>>, vector<64xi8>
%25 = vector.extract_strided_slice %24 {offsets = [0], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%26 = vector.insert %25, %cst_1 [0] : vector<8xi8> into vector<8x8xi8>
%27 = vector.extract_strided_slice %24 {offsets = [8], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%28 = vector.insert %27, %26 [1] : vector<8xi8> into vector<8x8xi8>
%29 = vector.extract_strided_slice %24 {offsets = [16], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%30 = vector.insert %29, %28 [2] : vector<8xi8> into vector<8x8xi8>
%31 = vector.extract_strided_slice %24 {offsets = [24], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%32 = vector.insert %31, %30 [3] : vector<8xi8> into vector<8x8xi8>
%33 = vector.extract_strided_slice %24 {offsets = [32], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%34 = vector.insert %33, %32 [4] : vector<8xi8> into vector<8x8xi8>
%35 = vector.extract_strided_slice %24 {offsets = [40], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%36 = vector.insert %35, %34 [5] : vector<8xi8> into vector<8x8xi8>
%37 = vector.extract_strided_slice %24 {offsets = [48], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%38 = vector.insert %37, %36 [6] : vector<8xi8> into vector<8x8xi8>
%39 = vector.extract_strided_slice %24 {offsets = [56], sizes = [8], strides = [1]} : vector<64xi8> to vector<8xi8>
%40 = vector.insert %39, %38 [7] : vector<8xi8> into vector<8x8xi8>
%collapse_shape_5 = memref.collapse_shape %subview_4 [[0], [1], [2, 3]] : memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>> into memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>
%41 = vector.load %collapse_shape_5[%19, %21, %c0] : memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>, vector<64xi4>
%42 = vector.extract_strided_slice %41 {offsets = [0], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%43 = vector.insert %42, %cst_0 [0] : vector<8xi4> into vector<8x8xi4>
%44 = vector.extract_strided_slice %41 {offsets = [8], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%45 = vector.insert %44, %43 [1] : vector<8xi4> into vector<8x8xi4>
%46 = vector.extract_strided_slice %41 {offsets = [16], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%47 = vector.insert %46, %45 [2] : vector<8xi4> into vector<8x8xi4>
%48 = vector.extract_strided_slice %41 {offsets = [24], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%49 = vector.insert %48, %47 [3] : vector<8xi4> into vector<8x8xi4>
%50 = vector.extract_strided_slice %41 {offsets = [32], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%51 = vector.insert %50, %49 [4] : vector<8xi4> into vector<8x8xi4>
%52 = vector.extract_strided_slice %41 {offsets = [40], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%53 = vector.insert %52, %51 [5] : vector<8xi4> into vector<8x8xi4>
%54 = vector.extract_strided_slice %41 {offsets = [48], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%55 = vector.insert %54, %53 [6] : vector<8xi4> into vector<8x8xi4>
%56 = vector.extract_strided_slice %41 {offsets = [56], sizes = [8], strides = [1]} : vector<64xi4> to vector<8xi4>
%57 = vector.insert %56, %55 [7] : vector<8xi4> into vector<8x8xi4>
%58 = arith.extsi %40 : vector<8x8xi8> to vector<8x8xi32>
%59 = arith.extsi %57 : vector<8x8xi4> to vector<8x8xi32>
%60 = vector.extract %59[0] : vector<8xi32> from vector<8x8xi32>
%61 = vector.insert_strided_slice %60, %cst {offsets = [0], strides = [1]} : vector<8xi32> into vector<64xi32>
%62 = vector.extract %59[1] : vector<8xi32> from vector<8x8xi32>
%63 = vector.insert_strided_slice %62, %61 {offsets = [8], strides = [1]} : vector<8xi32> into vector<64xi32>
%64 = vector.extract %59[2] : vector<8xi32> from vector<8x8xi32>
%65 = vector.insert_strided_slice %64, %63 {offsets = [16], strides = [1]} : vector<8xi32> into vector<64xi32>
%66 = vector.extract %59[3] : vector<8xi32> from vector<8x8xi32>
%67 = vector.insert_strided_slice %66, %65 {offsets = [24], strides = [1]} : vector<8xi32> into vector<64xi32>
%68 = vector.extract %59[4] : vector<8xi32> from vector<8x8xi32>
%69 = vector.insert_strided_slice %68, %67 {offsets = [32], strides = [1]} : vector<8xi32> into vector<64xi32>
%70 = vector.extract %59[5] : vector<8xi32> from vector<8x8xi32>
%71 = vector.insert_strided_slice %70, %69 {offsets = [40], strides = [1]} : vector<8xi32> into vector<64xi32>
%72 = vector.extract %59[6] : vector<8xi32> from vector<8x8xi32>
%73 = vector.insert_strided_slice %72, %71 {offsets = [48], strides = [1]} : vector<8xi32> into vector<64xi32>
%74 = vector.extract %59[7] : vector<8xi32> from vector<8x8xi32>
%75 = vector.insert_strided_slice %74, %73 {offsets = [56], strides = [1]} : vector<8xi32> into vector<64xi32>
%76 = vector.shuffle %75, %75 [0, 8, 16, 24, 32, 40, 48, 56, 1, 9, 17, 25, 33, 41, 49, 57, 2, 10, 18, 26, 34, 42, 50, 58, 3, 11, 19, 27, 35, 43, 51, 59, 4, 12, 20, 28, 36, 44, 52, 60, 5, 13, 21, 29, 37, 45, 53, 61, 6, 14, 22, 30, 38, 46, 54, 62, 7, 15, 23, 31, 39, 47, 55, 63] : vector<64xi32>, vector<64xi32>
%77 = vector.extract_strided_slice %76 {offsets = [0], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%78 = vector.extract_strided_slice %76 {offsets = [8], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%79 = vector.extract_strided_slice %76 {offsets = [16], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%80 = vector.extract_strided_slice %76 {offsets = [24], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%81 = vector.extract_strided_slice %76 {offsets = [32], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%82 = vector.extract_strided_slice %76 {offsets = [40], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%83 = vector.extract_strided_slice %76 {offsets = [48], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%84 = vector.extract_strided_slice %76 {offsets = [56], sizes = [8], strides = [1]} : vector<64xi32> to vector<8xi32>
%85 = vector.extract %58[0, 0] : i32 from vector<8x8xi32>
%86 = vector.broadcast %85 : i32 to vector<8xi32>
%87 = vector.extract %22[0, 0, 0] : vector<8xi32> from vector<1x1x8x8xi32>
%88 = arith.muli %86, %77 : vector<8xi32>
%89 = arith.addi %88, %87 : vector<8xi32>
%90 = vector.extract %58[1, 0] : i32 from vector<8x8xi32>
%91 = vector.broadcast %90 : i32 to vector<8xi32>
%92 = vector.extract %22[0, 0, 1] : vector<8xi32> from vector<1x1x8x8xi32>
%93 = arith.muli %91, %77 : vector<8xi32>
%94 = arith.addi %93, %92 : vector<8xi32>
%95 = vector.extract %58[2, 0] : i32 from vector<8x8xi32>
%96 = vector.broadcast %95 : i32 to vector<8xi32>
%97 = vector.extract %22[0, 0, 2] : vector<8xi32> from vector<1x1x8x8xi32>
%98 = arith.muli %96, %77 : vector<8xi32>
%99 = arith.addi %98, %97 : vector<8xi32>
%100 = vector.extract %58[3, 0] : i32 from vector<8x8xi32>
%101 = vector.broadcast %100 : i32 to vector<8xi32>
%102 = vector.extract %22[0, 0, 3] : vector<8xi32> from vector<1x1x8x8xi32>
%103 = arith.muli %101, %77 : vector<8xi32>
%104 = arith.addi %103, %102 : vector<8xi32>
%105 = vector.extract %58[4, 0] : i32 from vector<8x8xi32>
%106 = vector.broadcast %105 : i32 to vector<8xi32>
%107 = vector.extract %22[0, 0, 4] : vector<8xi32> from vector<1x1x8x8xi32>
%108 = arith.muli %106, %77 : vector<8xi32>
%109 = arith.addi %108, %107 : vector<8xi32>
%110 = vector.extract %58[5, 0] : i32 from vector<8x8xi32>
%111 = vector.broadcast %110 : i32 to vector<8xi32>
%112 = vector.extract %22[0, 0, 5] : vector<8xi32> from vector<1x1x8x8xi32>
%113 = arith.muli %111, %77 : vector<8xi32>
%114 = arith.addi %113, %112 : vector<8xi32>
%115 = vector.extract %58[6, 0] : i32 from vector<8x8xi32>
%116 = vector.broadcast %115 : i32 to vector<8xi32>
%117 = vector.extract %22[0, 0, 6] : vector<8xi32> from vector<1x1x8x8xi32>
%118 = arith.muli %116, %77 : vector<8xi32>
%119 = arith.addi %118, %117 : vector<8xi32>
%120 = vector.extract %58[7, 0] : i32 from vector<8x8xi32>
%121 = vector.broadcast %120 : i32 to vector<8xi32>
%122 = vector.extract %22[0, 0, 7] : vector<8xi32> from vector<1x1x8x8xi32>
%123 = arith.muli %121, %77 : vector<8xi32>
%124 = arith.addi %123, %122 : vector<8xi32>
%125 = vector.extract %58[0, 1] : i32 from vector<8x8xi32>
%126 = vector.broadcast %125 : i32 to vector<8xi32>
%127 = arith.muli %126, %78 : vector<8xi32>
%128 = arith.addi %127, %89 : vector<8xi32>
%129 = vector.extract %58[1, 1] : i32 from vector<8x8xi32>
%130 = vector.broadcast %129 : i32 to vector<8xi32>
%131 = arith.muli %130, %78 : vector<8xi32>
%132 = arith.addi %131, %94 : vector<8xi32>
%133 = vector.extract %58[2, 1] : i32 from vector<8x8xi32>
%134 = vector.broadcast %133 : i32 to vector<8xi32>
%135 = arith.muli %134, %78 : vector<8xi32>
%136 = arith.addi %135, %99 : vector<8xi32>
%137 = vector.extract %58[3, 1] : i32 from vector<8x8xi32>
%138 = vector.broadcast %137 : i32 to vector<8xi32>
%139 = arith.muli %138, %78 : vector<8xi32>
%140 = arith.addi %139, %104 : vector<8xi32>
%141 = vector.extract %58[4, 1] : i32 from vector<8x8xi32>
%142 = vector.broadcast %141 : i32 to vector<8xi32>
%143 = arith.muli %142, %78 : vector<8xi32>
%144 = arith.addi %143, %109 : vector<8xi32>
%145 = vector.extract %58[5, 1] : i32 from vector<8x8xi32>
%146 = vector.broadcast %145 : i32 to vector<8xi32>
%147 = arith.muli %146, %78 : vector<8xi32>
%148 = arith.addi %147, %114 : vector<8xi32>
%149 = vector.extract %58[6, 1] : i32 from vector<8x8xi32>
%150 = vector.broadcast %149 : i32 to vector<8xi32>
%151 = arith.muli %150, %78 : vector<8xi32>
%152 = arith.addi %151, %119 : vector<8xi32>
%153 = vector.extract %58[7, 1] : i32 from vector<8x8xi32>
%154 = vector.broadcast %153 : i32 to vector<8xi32>
%155 = arith.muli %154, %78 : vector<8xi32>
%156 = arith.addi %155, %124 : vector<8xi32>
%157 = vector.extract %58[0, 2] : i32 from vector<8x8xi32>
%158 = vector.broadcast %157 : i32 to vector<8xi32>
%159 = arith.muli %158, %79 : vector<8xi32>
%160 = arith.addi %159, %128 : vector<8xi32>
%161 = vector.extract %58[1, 2] : i32 from vector<8x8xi32>
%162 = vector.broadcast %161 : i32 to vector<8xi32>
%163 = arith.muli %162, %79 : vector<8xi32>
%164 = arith.addi %163, %132 : vector<8xi32>
%165 = vector.extract %58[2, 2] : i32 from vector<8x8xi32>
%166 = vector.broadcast %165 : i32 to vector<8xi32>
%167 = arith.muli %166, %79 : vector<8xi32>
%168 = arith.addi %167, %136 : vector<8xi32>
%169 = vector.extract %58[3, 2] : i32 from vector<8x8xi32>
%170 = vector.broadcast %169 : i32 to vector<8xi32>
%171 = arith.muli %170, %79 : vector<8xi32>
%172 = arith.addi %171, %140 : vector<8xi32>
%173 = vector.extract %58[4, 2] : i32 from vector<8x8xi32>
%174 = vector.broadcast %173 : i32 to vector<8xi32>
%175 = arith.muli %174, %79 : vector<8xi32>
%176 = arith.addi %175, %144 : vector<8xi32>
%177 = vector.extract %58[5, 2] : i32 from vector<8x8xi32>
%178 = vector.broadcast %177 : i32 to vector<8xi32>
%179 = arith.muli %178, %79 : vector<8xi32>
%180 = arith.addi %179, %148 : vector<8xi32>
%181 = vector.extract %58[6, 2] : i32 from vector<8x8xi32>
%182 = vector.broadcast %181 : i32 to vector<8xi32>
%183 = arith.muli %182, %79 : vector<8xi32>
%184 = arith.addi %183, %152 : vector<8xi32>
%185 = vector.extract %58[7, 2] : i32 from vector<8x8xi32>
%186 = vector.broadcast %185 : i32 to vector<8xi32>
%187 = arith.muli %186, %79 : vector<8xi32>
%188 = arith.addi %187, %156 : vector<8xi32>
%189 = vector.extract %58[0, 3] : i32 from vector<8x8xi32>
%190 = vector.broadcast %189 : i32 to vector<8xi32>
%191 = arith.muli %190, %80 : vector<8xi32>
%192 = arith.addi %191, %160 : vector<8xi32>
%193 = vector.extract %58[1, 3] : i32 from vector<8x8xi32>
%194 = vector.broadcast %193 : i32 to vector<8xi32>
%195 = arith.muli %194, %80 : vector<8xi32>
%196 = arith.addi %195, %164 : vector<8xi32>
%197 = vector.extract %58[2, 3] : i32 from vector<8x8xi32>
%198 = vector.broadcast %197 : i32 to vector<8xi32>
%199 = arith.muli %198, %80 : vector<8xi32>
%200 = arith.addi %199, %168 : vector<8xi32>
%201 = vector.extract %58[3, 3] : i32 from vector<8x8xi32>
%202 = vector.broadcast %201 : i32 to vector<8xi32>
%203 = arith.muli %202, %80 : vector<8xi32>
%204 = arith.addi %203, %172 : vector<8xi32>
%205 = vector.extract %58[4, 3] : i32 from vector<8x8xi32>
%206 = vector.broadcast %205 : i32 to vector<8xi32>
%207 = arith.muli %206, %80 : vector<8xi32>
%208 = arith.addi %207, %176 : vector<8xi32>
%209 = vector.extract %58[5, 3] : i32 from vector<8x8xi32>
%210 = vector.broadcast %209 : i32 to vector<8xi32>
%211 = arith.muli %210, %80 : vector<8xi32>
%212 = arith.addi %211, %180 : vector<8xi32>
%213 = vector.extract %58[6, 3] : i32 from vector<8x8xi32>
%214 = vector.broadcast %213 : i32 to vector<8xi32>
%215 = arith.muli %214, %80 : vector<8xi32>
%216 = arith.addi %215, %184 : vector<8xi32>
%217 = vector.extract %58[7, 3] : i32 from vector<8x8xi32>
%218 = vector.broadcast %217 : i32 to vector<8xi32>
%219 = arith.muli %218, %80 : vector<8xi32>
%220 = arith.addi %219, %188 : vector<8xi32>
%221 = vector.extract %58[0, 4] : i32 from vector<8x8xi32>
%222 = vector.broadcast %221 : i32 to vector<8xi32>
%223 = arith.muli %222, %81 : vector<8xi32>
%224 = arith.addi %223, %192 : vector<8xi32>
%225 = vector.extract %58[1, 4] : i32 from vector<8x8xi32>
%226 = vector.broadcast %225 : i32 to vector<8xi32>
%227 = arith.muli %226, %81 : vector<8xi32>
%228 = arith.addi %227, %196 : vector<8xi32>
%229 = vector.extract %58[2, 4] : i32 from vector<8x8xi32>
%230 = vector.broadcast %229 : i32 to vector<8xi32>
%231 = arith.muli %230, %81 : vector<8xi32>
%232 = arith.addi %231, %200 : vector<8xi32>
%233 = vector.extract %58[3, 4] : i32 from vector<8x8xi32>
%234 = vector.broadcast %233 : i32 to vector<8xi32>
%235 = arith.muli %234, %81 : vector<8xi32>
%236 = arith.addi %235, %204 : vector<8xi32>
%237 = vector.extract %58[4, 4] : i32 from vector<8x8xi32>
%238 = vector.broadcast %237 : i32 to vector<8xi32>
%239 = arith.muli %238, %81 : vector<8xi32>
%240 = arith.addi %239, %208 : vector<8xi32>
%241 = vector.extract %58[5, 4] : i32 from vector<8x8xi32>
%242 = vector.broadcast %241 : i32 to vector<8xi32>
%243 = arith.muli %242, %81 : vector<8xi32>
%244 = arith.addi %243, %212 : vector<8xi32>
%245 = vector.extract %58[6, 4] : i32 from vector<8x8xi32>
%246 = vector.broadcast %245 : i32 to vector<8xi32>
%247 = arith.muli %246, %81 : vector<8xi32>
%248 = arith.addi %247, %216 : vector<8xi32>
%249 = vector.extract %58[7, 4] : i32 from vector<8x8xi32>
%250 = vector.broadcast %249 : i32 to vector<8xi32>
%251 = arith.muli %250, %81 : vector<8xi32>
%252 = arith.addi %251, %220 : vector<8xi32>
%253 = vector.extract %58[0, 5] : i32 from vector<8x8xi32>
%254 = vector.broadcast %253 : i32 to vector<8xi32>
%255 = arith.muli %254, %82 : vector<8xi32>
%256 = arith.addi %255, %224 : vector<8xi32>
%257 = vector.extract %58[1, 5] : i32 from vector<8x8xi32>
%258 = vector.broadcast %257 : i32 to vector<8xi32>
%259 = arith.muli %258, %82 : vector<8xi32>
%260 = arith.addi %259, %228 : vector<8xi32>
%261 = vector.extract %58[2, 5] : i32 from vector<8x8xi32>
%262 = vector.broadcast %261 : i32 to vector<8xi32>
%263 = arith.muli %262, %82 : vector<8xi32>
%264 = arith.addi %263, %232 : vector<8xi32>
%265 = vector.extract %58[3, 5] : i32 from vector<8x8xi32>
%266 = vector.broadcast %265 : i32 to vector<8xi32>
%267 = arith.muli %266, %82 : vector<8xi32>
%268 = arith.addi %267, %236 : vector<8xi32>
%269 = vector.extract %58[4, 5] : i32 from vector<8x8xi32>
%270 = vector.broadcast %269 : i32 to vector<8xi32>
%271 = arith.muli %270, %82 : vector<8xi32>
%272 = arith.addi %271, %240 : vector<8xi32>
%273 = vector.extract %58[5, 5] : i32 from vector<8x8xi32>
%274 = vector.broadcast %273 : i32 to vector<8xi32>
%275 = arith.muli %274, %82 : vector<8xi32>
%276 = arith.addi %275, %244 : vector<8xi32>
%277 = vector.extract %58[6, 5] : i32 from vector<8x8xi32>
%278 = vector.broadcast %277 : i32 to vector<8xi32>
%279 = arith.muli %278, %82 : vector<8xi32>
%280 = arith.addi %279, %248 : vector<8xi32>
%281 = vector.extract %58[7, 5] : i32 from vector<8x8xi32>
%282 = vector.broadcast %281 : i32 to vector<8xi32>
%283 = arith.muli %282, %82 : vector<8xi32>
%284 = arith.addi %283, %252 : vector<8xi32>
%285 = vector.extract %58[0, 6] : i32 from vector<8x8xi32>
%286 = vector.broadcast %285 : i32 to vector<8xi32>
%287 = arith.muli %286, %83 : vector<8xi32>
%288 = arith.addi %287, %256 : vector<8xi32>
%289 = vector.extract %58[1, 6] : i32 from vector<8x8xi32>
%290 = vector.broadcast %289 : i32 to vector<8xi32>
%291 = arith.muli %290, %83 : vector<8xi32>
%292 = arith.addi %291, %260 : vector<8xi32>
%293 = vector.extract %58[2, 6] : i32 from vector<8x8xi32>
%294 = vector.broadcast %293 : i32 to vector<8xi32>
%295 = arith.muli %294, %83 : vector<8xi32>
%296 = arith.addi %295, %264 : vector<8xi32>
%297 = vector.extract %58[3, 6] : i32 from vector<8x8xi32>
%298 = vector.broadcast %297 : i32 to vector<8xi32>
%299 = arith.muli %298, %83 : vector<8xi32>
%300 = arith.addi %299, %268 : vector<8xi32>
%301 = vector.extract %58[4, 6] : i32 from vector<8x8xi32>
%302 = vector.broadcast %301 : i32 to vector<8xi32>
%303 = arith.muli %302, %83 : vector<8xi32>
%304 = arith.addi %303, %272 : vector<8xi32>
%305 = vector.extract %58[5, 6] : i32 from vector<8x8xi32>
%306 = vector.broadcast %305 : i32 to vector<8xi32>
%307 = arith.muli %306, %83 : vector<8xi32>
%308 = arith.addi %307, %276 : vector<8xi32>
%309 = vector.extract %58[6, 6] : i32 from vector<8x8xi32>
%310 = vector.broadcast %309 : i32 to vector<8xi32>
%311 = arith.muli %310, %83 : vector<8xi32>
%312 = arith.addi %311, %280 : vector<8xi32>
%313 = vector.extract %58[7, 6] : i32 from vector<8x8xi32>
%314 = vector.broadcast %313 : i32 to vector<8xi32>
%315 = arith.muli %314, %83 : vector<8xi32>
%316 = arith.addi %315, %284 : vector<8xi32>
%317 = vector.extract %58[0, 7] : i32 from vector<8x8xi32>
%318 = vector.broadcast %317 : i32 to vector<8xi32>
%319 = arith.muli %318, %84 : vector<8xi32>
%320 = arith.addi %319, %288 : vector<8xi32>
%321 = vector.insert %320, %cst_2 [0] : vector<8xi32> into vector<8x8xi32>
%322 = vector.extract %58[1, 7] : i32 from vector<8x8xi32>
%323 = vector.broadcast %322 : i32 to vector<8xi32>
%324 = arith.muli %323, %84 : vector<8xi32>
%325 = arith.addi %324, %292 : vector<8xi32>
%326 = vector.insert %325, %321 [1] : vector<8xi32> into vector<8x8xi32>
%327 = vector.extract %58[2, 7] : i32 from vector<8x8xi32>
%328 = vector.broadcast %327 : i32 to vector<8xi32>
%329 = arith.muli %328, %84 : vector<8xi32>
%330 = arith.addi %329, %296 : vector<8xi32>
%331 = vector.insert %330, %326 [2] : vector<8xi32> into vector<8x8xi32>
%332 = vector.extract %58[3, 7] : i32 from vector<8x8xi32>
%333 = vector.broadcast %332 : i32 to vector<8xi32>
%334 = arith.muli %333, %84 : vector<8xi32>
%335 = arith.addi %334, %300 : vector<8xi32>
%336 = vector.insert %335, %331 [3] : vector<8xi32> into vector<8x8xi32>
%337 = vector.extract %58[4, 7] : i32 from vector<8x8xi32>
%338 = vector.broadcast %337 : i32 to vector<8xi32>
%339 = arith.muli %338, %84 : vector<8xi32>
%340 = arith.addi %339, %304 : vector<8xi32>
%341 = vector.insert %340, %336 [4] : vector<8xi32> into vector<8x8xi32>
%342 = vector.extract %58[5, 7] : i32 from vector<8x8xi32>
%343 = vector.broadcast %342 : i32 to vector<8xi32>
%344 = arith.muli %343, %84 : vector<8xi32>
%345 = arith.addi %344, %308 : vector<8xi32>
%346 = vector.insert %345, %341 [5] : vector<8xi32> into vector<8x8xi32>
%347 = vector.extract %58[6, 7] : i32 from vector<8x8xi32>
%348 = vector.broadcast %347 : i32 to vector<8xi32>
%349 = arith.muli %348, %84 : vector<8xi32>
%350 = arith.addi %349, %312 : vector<8xi32>
%351 = vector.insert %350, %346 [6] : vector<8xi32> into vector<8x8xi32>
%352 = vector.extract %58[7, 7] : i32 from vector<8x8xi32>
%353 = vector.broadcast %352 : i32 to vector<8xi32>
%354 = arith.muli %353, %84 : vector<8xi32>
%355 = arith.addi %354, %316 : vector<8xi32>
%356 = vector.insert %355, %351 [7] : vector<8xi32> into vector<8x8xi32>
%357 = vector.broadcast %356 : vector<8x8xi32> to vector<1x1x8x8xi32>
%358 = arith.addi %21, %c1 : index
cf.br ^bb7(%358, %357 : index, vector<1x1x8x8xi32>)
^bb9: // pred: ^bb7
%359 = vector.extract %22[0, 0, 0] : vector<8xi32> from vector<1x1x8x8xi32>
%360 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%361 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %359, %8[%360, %361, %c0, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%362 = vector.extract %22[0, 0, 1] : vector<8xi32> from vector<1x1x8x8xi32>
%363 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%364 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %362, %8[%363, %364, %c1, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%365 = vector.extract %22[0, 0, 2] : vector<8xi32> from vector<1x1x8x8xi32>
%366 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%367 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %365, %8[%366, %367, %c2, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%368 = vector.extract %22[0, 0, 3] : vector<8xi32> from vector<1x1x8x8xi32>
%369 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%370 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %368, %8[%369, %370, %c3, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%371 = vector.extract %22[0, 0, 4] : vector<8xi32> from vector<1x1x8x8xi32>
%372 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%373 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %371, %8[%372, %373, %c4, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%374 = vector.extract %22[0, 0, 5] : vector<8xi32> from vector<1x1x8x8xi32>
%375 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%376 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %374, %8[%375, %376, %c5, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%377 = vector.extract %22[0, 0, 6] : vector<8xi32> from vector<1x1x8x8xi32>
%378 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%379 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %377, %8[%378, %379, %c6, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%380 = vector.extract %22[0, 0, 7] : vector<8xi32> from vector<1x1x8x8xi32>
%381 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%13, %17]
%382 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%15, %19]
vector.store %380, %8[%381, %382, %c7, %c0] : memref<64x256x8x8xi32, strided<[16384, 64, 8, 1], offset: ?>>, vector<8xi32>
%383 = arith.addi %19, %c1 : index
cf.br ^bb6(%383 : index)
^bb10: // pred: ^bb6
%384 = arith.addi %17, %c1 : index
cf.br ^bb5(%384 : index)
^bb11: // pred: ^bb5
%385 = arith.addi %15, %12 : index
cf.br ^bb3(%385 : index)
^bb12: // pred: ^bb3
%386 = arith.addi %13, %10 : index
cf.br ^bb1(%386 : index)
^bb13: // pred: ^bb1
return
}
}
Type emulation cannot work for subviews. You can only do type emulation for memref types that can be linearized (you can't type-emulate multi-dimensional memrefs).
All subviews must be folded into their uses before type emulation. It seems like the collapse_shape is blocking the folding here.
I guess we need separate patterns to fold collapse shapes into loads/stores. Those should be easy to do.
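To make "can be linearized" concrete, here is a small Python sketch (not IREE code; the helper names are hypothetical) that maps a strided multi-dimensional index to a single element offset and then to a byte and nibble in i8 storage. It uses the strides from the issue and assumes a static offset of 0, whereas the real subview has a dynamic offset.

```python
# Sketch: a strided multi-dimensional index first becomes one element
# offset; only then can an i4 element be located inside i8 storage.

def linear_offset(indices, strides, offset=0):
    """Element offset of `indices` in a strided memref (units: elements)."""
    assert len(indices) == len(strides)
    return offset + sum(i * s for i, s in zip(indices, strides))


def i4_to_byte_position(elem_offset):
    """Map an i4 element offset to (byte index, nibble index within that byte)."""
    return elem_offset // 2, elem_offset % 2


# Layout from the issue: memref<256x512x8x8xi4, strided<[32768, 64, 8, 1]>>,
# with the dynamic offset assumed to be 0 for illustration.
strides = [32768, 64, 8, 1]
elem = linear_offset([1, 2, 3, 4], strides)
print(elem)                        # 32924
print(i4_to_byte_position(elem))   # (16462, 0)
```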
Hmm, I thought there were already some existing patterns that would handle cases like this. Maybe I wrote some a while ago and never upstreamed it :p I can take a closer look on Monday, but I think this collapse shape should be easy to fold.
%subview_4 = memref.subview %7[%15, 0, 0, 0] [4, 512, 8, 8] [1, 1, 1, 1]
: memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
to memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
%collapse_shape_5 = memref.collapse_shape %subview_4 [[0], [1], [2, 3]]
: memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
into memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>
%41 = vector.load %collapse_shape_5[%19, %21, %c0] : memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>, vector<64xi4>
Ok, I think we can rewrite this into a vector.transfer_read on the subview, and then a vector.shape_cast to collapse the read. Then the subview will fold with the transfer_read, and we will also need to add narrow type emulation for vector.shape_cast. What do you guys think?
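A rough Python model of the proposed rewrite, assuming the strides of %subview_4 and using plain integers (element positions) in place of i4 values: reading the 8x8 tile element-wise and flattening it row-major, which is what vector.shape_cast does to element order, gives the same sequence as a 64-element read from the collapsed view. The helper names are hypothetical.

```python
# Sketch: an 8x8 tile read followed by a row-major flatten equals a
# 64-element contiguous read, because the two inner dims have strides [8, 1].

STRIDES = [32768, 64, 8, 1]  # strides of %subview_4 in the issue (offset taken as 0)


def read_tile(buf, i, j):
    """Element-wise read of the [i, j, :, :] 8x8 tile (like an unrolled transfer_read)."""
    base = i * STRIDES[0] + j * STRIDES[1]
    return [[buf[base + r * STRIDES[2] + c * STRIDES[3]] for c in range(8)]
            for r in range(8)]


def shape_cast_flatten(tile):
    """Row-major flatten, i.e. the element order a shape_cast 8x8 -> 64 produces."""
    return [x for row in tile for x in row]


def read_contiguous(buf, i, j, n=64):
    """A 1-D read of n elements starting at the collapsed index [i, j, 0]."""
    base = i * STRIDES[0] + j * STRIDES[1]
    return buf[base:base + n]


buf = list(range(2 * STRIDES[0]))  # fake backing store: value == linear position
assert shape_cast_flatten(read_tile(buf, 1, 3)) == read_contiguous(buf, 1, 3)
```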
Anything that is multi-dimensional will be hard to support for type emulation. Fold the shape_cast as well?
@dcaballe why do we have memref.collapse_shape? Did you enable vector flatten in your local build?
Ok, I think we can rewrite this into a vector.transfer_read on the subview, and then a vector.shape_cast to collapse the read.
The emulation happens way after vector lowering. Are you saying to add the pattern in vector lowering? I think the collapse_shape is intended. Otherwise we will unroll multi-dimensional vector.transfer_read ops, which introduces scalar loads.
The type emulation cannot work for subviews. You can only do type emulation for memrefs types that can be linearized.
Since we are able to collapse the memref, the memref types below seem to be linearizable. Can we teach the emulation to handle the subview -> memref.collapse_shape chain?
%subview_4 = memref.subview %7[%15, 0, 0, 0] [4, 512, 8, 8] [1, 1, 1, 1]
: memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
to memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
%collapse_shape_5 = memref.collapse_shape %subview_4 [[0], [1], [2, 3]]
: memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
into memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>
%41 = vector.load %collapse_shape_5[%19, %21, %c0] : memref<4x512x64xi4, strided<[32768, 64, 1], offset: ?>>, vector<64xi4>
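One way to see why this collapse is legal is to check contiguity within each reassociation group. The Python sketch below is a simplified, static-shape version of that check; the real MLIR verifier also deals with dynamic strides and unit dimensions, which this ignores. `collapse_strides` is a hypothetical helper.

```python
# Sketch: inside a reassociation group, each dim's stride must equal the
# next dim's size times its stride (row-major contiguity). The collapsed
# dim takes the product of the sizes and the innermost stride of the group.

def collapse_strides(sizes, strides, groups):
    """Return (sizes, strides) after collapsing `groups`, or None if a group is not contiguous."""
    new_sizes, new_strides = [], []
    for group in groups:
        for a, b in zip(group, group[1:]):
            if strides[a] != sizes[b] * strides[b]:
                return None  # group is not contiguous; the collapse is invalid
        size = 1
        for d in group:
            size *= sizes[d]
        new_sizes.append(size)
        new_strides.append(strides[group[-1]])
    return new_sizes, new_strides


# The subview/collapse_shape pair from the issue.
print(collapse_strides([4, 512, 8, 8], [32768, 64, 8, 1], [[0], [1], [2, 3]]))
# ([4, 512, 64], [32768, 64, 1])
```

The result matches the memref<4x512x64xi4, strided<[32768, 64, 1]>> type in the IR above.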
Thanks for the quick response, appreciate it!
@dcaballe why do we have memref.collapse_shape? Did you enable vector flatten in your local build?
Yes, I'm experimenting with this here: https://github.com/openxla/iree/pull/16456. You may have also seen some fixes/extensions to that pass. More to come!
Regarding the emulation discussion, I'm not sure I totally understand. Why does everything have to be folded into the transfer read/write ops? Memref subview and collapse ops are only index computation, so it should be easy to add emulation for them directly. As we discussed in https://github.com/llvm/llvm-project/pull/80517, it's unlikely we can fold all the memref subviews into their transfer read/write op consumers, so we may need direct emulation support for them. Multi-dimensional support should be OK as long as the number of elements in the contiguous dimension times their bitwidth is a multiple of one byte. Am I missing something?
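A minimal Python sketch of the condition described above, for static shapes only. The function name is hypothetical, and the checks on the outer strides and base offset are an added assumption (every row must also start on a byte boundary), not something stated in the comment.

```python
# Sketch: a multi-dimensional sub-byte memref maps onto i8 storage if each
# row of the contiguous (innermost) dim is a whole number of bytes.
# Assumption beyond the comment: outer strides and the base offset must
# also land on byte boundaries.

def byte_aligned_rows(inner_elems, bitwidth, outer_strides=(), offset=0):
    """True if rows, outer strides, and the base offset are all whole bytes."""
    if (inner_elems * bitwidth) % 8 != 0:   # row length must be whole bytes
        return False
    if (offset * bitwidth) % 8 != 0:        # base must start on a byte
        return False
    return all((s * bitwidth) % 8 == 0 for s in outer_strides)


print(byte_aligned_rows(64, 4, outer_strides=[32768, 64]))  # True
print(byte_aligned_rows(3, 4))                              # False: 12-bit rows
print(byte_aligned_rows(64, 4, offset=1))                   # False: base mid-byte
```

Note that the subview in the issue has a dynamic offset (`offset: ?`), so the offset check could not be decided statically there.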
That is a very restricted case and can only support static shapes, so I don't count it as something that can be "supported". It will force you to walk a very tight path. Moreover, I think it is unnecessary. If you cannot fold a memref.subview with the load/store, then there are going to be other issues anyway. So adding type propagation for subviews is treating a symptom and not the cause (and is unnecessarily complicated). https://github.com/llvm/llvm-project/pull/80517 is really pointing out a pitfall in vector.transfer_read semantics. So again, that's a symptom, not the root cause.
Probably not a high priority, but this is a compilation issue that would impact any Android phone without SVE support, so I'm not sure we can just skip it. IIRC, the missing folding for this case was due to the transfer read having in_bounds = false. The dynamic shape was due to not being able to apply masking (the target doesn't support it) or peeling (related to https://github.com/openxla/iree/issues/16406), not due to a truly dynamic case, so it should be a valid case (e.g., the number of i4 elements is even but not a multiple of the vector size, so a dynamic dim is generated in the absence of masking/peeling).
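To illustrate the sizes involved, a Python sketch with made-up numbers (n = 100 i4 elements, vector size 64); it is not IREE's tiling logic. Without masking or peeling, the remainder is not a multiple of the vector size, which is where the dynamic dimension comes from, yet an even element count keeps that remainder byte-aligned.

```python
# Sketch with hypothetical sizes: split n elements into full vectors of
# size v plus a remainder (the part that becomes a dynamic dim when
# neither masking nor peeling is available).

def split_main_and_tail(n, v):
    """Return (elements covered by full vectors, remaining elements)."""
    main = (n // v) * v
    return main, n - main


main, tail = split_main_and_tail(100, 64)
print(main, tail)          # 64 36
print(tail * 4 % 8 == 0)   # True: 36 i4 elements = 18 whole bytes
```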
Anyway, let's focus on the common case since that is the main blocker right now! Is it clear how to move forward, @Max191? Do you have code for this or planning to help with it? Thanks!
Sorry, I lost this thread for a few days. I found my old pattern, but I'm not sure it is what we would want to do. I believe it would rewrite this to something like:
%subview_4 = memref.subview %7[%15, 0, 0, 0] [4, 512, 8, 8] [1, 1, 1, 1]
: memref<256x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
to memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>
%41 = vector.load %subview_4[%19, %21, %c0, %c0] : memref<4x512x8x8xi4, strided<[32768, 64, 8, 1], offset: ?>>, vector<64xi4>
IIRC, this worked fine when I tested it for a similar case, and this would also get rid of the collapse_shape, but it does create a strange vector.load (a vector<64xi4> read from a memref whose innermost dimension has only 8 elements).
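The folded form can be sanity-checked with a small Python sketch (indices i = 2, j = 5 are arbitrary; the strides are those of %subview_4 with the offset taken as 0): the 4-D start index [i, j, 0, 0] lands on the same offset as the collapsed [i, j, 0], and the 64 elements read cover exactly one 8x8 tile, even though the innermost dimension has only 8 elements.

```python
# Sketch: compare the element positions touched by the folded 4-D load
# with the positions of the in-bounds 8x8 tile.

SIZES_4D, STRIDES_4D = [4, 512, 8, 8], [32768, 64, 8, 1]
STRIDES_3D = [32768, 64, 1]  # after collapsing [[0], [1], [2, 3]]


def offset(idx, strides):
    """Linear element offset of `idx` (base offset assumed 0)."""
    return sum(i * s for i, s in zip(idx, strides))


i, j = 2, 5
start4 = offset([i, j, 0, 0], STRIDES_4D)
start3 = offset([i, j, 0], STRIDES_3D)
assert start4 == start3

read = set(range(start4, start4 + 64))  # positions read by vector<64xi4>
tile = {offset([i, j, r, c], STRIDES_4D) for r in range(8) for c in range(8)}
assert read == tile
print(64 > SIZES_4D[-1])  # True: the vector is longer than the last dim
```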
I would expect the vector.load to generate a vector<1x1x8x8xi4> and shape cast it down to vector<64xi4>. So at this stage this seems OK. We are just translating through, and everything gets linearized anyway.
Hmm... A bit borderline to me. This is probably something that would fit into the versatility of transfer reads/writes, but I'm doubtful whether we want this at the vector load/store level, where things are expected to be simpler. There's also the fact that this folding is actually undoing the collapse on the memref, so I think we shouldn't consider this a canonical form or a requirement to apply emulation. What are the main blockers to supporting emulation on memref.collapse_shape directly?
Like I said, you cannot do emulation for multi-dimensional types. Adding that is just for a very specific use case and doesn't actually solve the problem.
Sorry, I'm still missing this point. Why can't we apply emulation to multi-dim memrefs? Could you provide an example for my understanding? Thanks in advance!
You cannot emulate a memref<?x?xi4, strides=[?, ?], offset=?>. Everything else is very special-cased and just punts the problem downstream or works around a real issue. The only realistic way of supporting emulation is on a type memref<?xi4, offset=?>.
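For reference, a Python sketch of what emulating a 1-D memref<?xi4, offset: ?> on i8 storage boils down to: a single linear index selects a byte and a nibble. The nibble order (even elements in the low nibble) and unsigned values are assumptions made for illustration; `load_i4`/`store_i4` are hypothetical helpers, not MLIR or IREE APIs.

```python
# Sketch: 1-D i4 buffer emulated on bytes. Assumed layout: two elements
# per byte, even linear index in the low nibble, values unsigned.

def load_i4(storage, offset, index):
    """Read element `index` of a 1-D i4 view whose first element sits at `offset`."""
    linear = offset + index          # a single linear index is all we need
    shift = 4 * (linear % 2)         # low nibble for even, high nibble for odd
    return (storage[linear // 2] >> shift) & 0xF


def store_i4(storage, offset, index, value):
    """Write element `index`, preserving the other nibble of the byte."""
    linear = offset + index
    shift = 4 * (linear % 2)
    b = linear // 2
    storage[b] = (storage[b] & ~(0xF << shift) & 0xFF) | ((value & 0xF) << shift)


buf = bytearray(4)
for k, v in enumerate([1, 2, 3, 4, 5]):
    store_i4(buf, 1, k, v)    # offset 1: the view starts mid-byte
print(list(buf))              # [16, 50, 84, 0]
print([load_i4(buf, 1, k) for k in range(5)])  # [1, 2, 3, 4, 5]
```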
Thanks for clarifying. I think there must have been some miscommunication. We all agree that a memref with multi-dimensional dynamic shapes can't be emulated. The discussion point is whether we should only support cases where all the address computation (i.e., memref.subview, memref.collapse_shape, ...) can be folded into the vector load/store, or whether we should also add emulation support for cases that are still valid from the emulation point of view but whose address computation can't be folded into a single vector load/store.
In the example I provided above, it's easy to prove that the memref.subview and the memref.collapse_shape are valid from the emulation point of view (i.e., collapsed dimensions are contiguous, the number of elements of the trailing dimension is even, etc.), so we just have to rewrite them using the corresponding i8 offset. That would be cleaner than just enforcing a folding of those ops into a vector.load/store that would end up reading beyond the dimension bounds.
It's interesting because I'm now thinking that folding the address computation into the vector load/store could actually lead to emulation cases that shouldn't be supported. For example, we may think that a memref<6xi4> can be emulated, but that wouldn't be the case if the memref comes from an original memref<2x3xi4>. Same for a memref<?xi4>, assuming that it can be collapsed from a memref<?x?xi4>. How do we make sure that these cases are not accidentally emulated?
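A quick arithmetic check of that example in Python: the flat memref<6xi4> is 24 bits (3 whole bytes), but each row of memref<2x3xi4> is 12 bits, so row 1 begins halfway through a byte.

```python
# Sketch: bit position at which each row of a 2x3 i4 buffer starts.

BITS = 4  # i4


def row_start_bits(row, row_len, bits=BITS):
    """Bit offset of the first element of `row` in a row-major buffer."""
    return row * row_len * bits


total_bits = 6 * BITS
print(total_bits % 8 == 0)                                # True: 3 whole bytes
print([row_start_bits(r, 3) for r in range(2)])           # [0, 12]
print([row_start_bits(r, 3) % 8 == 0 for r in range(2)])  # [True, False]
```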
+1 on what Diego said. It would be awesome if we could emulate sub-byte types for the memref.subview + memref.collapse_shape case. We really need that to generate non-scalar loads/stores. On most CPUs, we usually want to flatten the innermost dimensions. Diego and I have been looking at flattening. Introducing a memref.collapse_shape sounds reasonable to me. We don't want a vector.shape_cast for many reasons.
I don't see how this is cleaner. This is supporting a very narrow case. At the end of the day, subviews are really "not a thing". At the LLVM/SPIR-V level you only have loads and stores, so eventually all the subviews have to fold away. There is no magic here. You can only write emulation after linearization to get the offset + base pointer. How would you handle i5 or i3 with multi-dimensional memrefs? I don't see how you would do that. But if you linearize the memrefs, you can support any bit width.
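To illustrate the claim that a linearized view supports any bit width, here is a Python sketch that packs and loads i3 values that straddle byte boundaries. LSB-first packing, little-endian byte order, and unsigned values are assumptions; `pack_bits`/`load_bits` are hypothetical helpers, not part of MLIR or IREE.

```python
# Sketch: with a single linear index, element k of an iN buffer starts at
# bit k * N, regardless of whether N divides 8.

def pack_bits(values, bitwidth):
    """Pack unsigned values LSB-first into a little-endian byte string."""
    word = 0
    for k, v in enumerate(values):
        word |= (v & ((1 << bitwidth) - 1)) << (k * bitwidth)
    nbytes = (len(values) * bitwidth + 7) // 8
    return word.to_bytes(nbytes, "little")


def load_bits(storage, bitwidth, linear_index):
    """Load element `linear_index`, reading every byte it overlaps."""
    bit = linear_index * bitwidth
    first, last = bit // 8, (bit + bitwidth - 1) // 8
    word = int.from_bytes(storage[first:last + 1], "little")
    return (word >> (bit % 8)) & ((1 << bitwidth) - 1)


vals = [5, 1, 7, 0, 3, 6]   # i3 values
packed = pack_bits(vals, 3)  # 18 bits -> 3 bytes
print(len(packed))           # 3
print([load_bits(packed, 3, k) for k in range(len(vals))])  # [5, 1, 7, 0, 3, 6]
```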
You are actually making the case for supporting only linearized memrefs. In a memref<?xi4> you have all the information needed for emulation. If you try to emulate memref<?x?xi4>, you cannot do so unless you prove it is contiguous (and if it is contiguous, then it is valid to collapse it). So if you stick to 1-D memrefs as a prerequisite for emulation, you don't have any ambiguity.
I don't follow. This is mixing concerns in my book. Subviews + collapse shape to get a 1-D memref? If you have that, why can't you fold these into loads/stores? I am not saying anything with respect to scalar or vector loads/stores. At this level, if you have scalar loads, you fold things with memref.load/store. If you have vector loads, you fold things with vector.load/store. At this level, if you are dealing with multi-dimensional vectors, you aren't at a place in the compilation flow that needs to introduce emulation.