RFC: TopK perf improvements
Request description
This RFC proposes speeding up the current TopK implementation.
Problem that this proposal addresses
In one LLM model, the TopK function accounts for about 7% of the overall execution time. The function is applied to inputs of 4x8000xf32 and 1x256128xf32, producing outputs of 4x40xf32 with 4x40xi32 and 1x40xf32 with 1x40xi32, respectively. Analysis showed that the current implementation attempts to sort every input element into the output, which is suboptimal.
Proposal
Generate a function that avoids attempting to sort every input element. The function should scale: the smaller the ratio of output size to input size, the better the performance.
If we track the smallest element added to the output and the number of elements added, then once the output is full, any input element smaller than that smallest element does not need to be considered for insertion (sorting) into the output. Further, reading the input elements and comparing them to the smallest added element can be done using vector registers (one holding the input and one holding the broadcast smallest element). Using vector registers saves up to 30 cycles per 16 elements (assuming a vector register holds 16 input elements).
For proper tiling, the proposal uses the DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize tiling pipeline, which tiles the input to the number of elements in a vector register (1x16xf32).
Note: in the initial implementation, if the input size is not a multiple of the number of elements in the vector register, the operation is delegated to the current implementation. This will be fixed in subsequent revisions of the code.
The implementation is hybrid: reading and comparison are done using vectors, while insertion/sorting into the output is done using scalars.
A further optimization is possible by tracking the number of elements added to the output. If the output is not yet completely filled and the element being considered is smaller than the smallest element added so far, there is no need to shift the output elements to keep them sorted: the new element can be appended at the very end and the output will still be sorted.
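To make the idea concrete, here is a minimal Python sketch of the filtering scheme described above. It is not the proposed codegen: the function name `topk_filtered`, the constant `LANES`, and the use of Python lists in place of vector registers are assumptions made purely for illustration.

```python
LANES = 16  # assumed number of input elements per vector register


def topk_filtered(row, k):
    """Return (values, indices) of the k largest elements, largest first."""
    if k <= 0:
        return [], []
    values, indices = [], []  # sorted output, largest first
    for base in range(0, len(row), LANES):
        chunk = row[base:base + LANES]
        full = len(values) == k
        smallest = values[-1] if values else None
        # "Vector" step: one mask per chunk. While the output is not full,
        # every lane passes so that the output gets filled.
        mask = [(not full) or x > smallest for x in chunk]
        if not any(mask):
            continue  # the whole chunk is rejected without touching the output
        # Scalar step: only the lanes that passed the mask are inserted.
        for lane, keep in enumerate(mask):
            if not keep:
                continue
            x = chunk[lane]
            # The threshold may have risen since the mask was computed.
            if len(values) == k and x <= values[-1]:
                continue
            # Linear scan for the first output element smaller than x.
            pos = len(values)  # default: append at the end, no shifting
            for i, v in enumerate(values):
                if v < x:
                    pos = i
                    break
            values.insert(pos, x)
            indices.insert(pos, base + lane)
            if len(values) > k:  # drop the element pushed past the end
                values.pop()
                indices.pop()
    return values, indices
```

With a small output and a large input, most chunks fail the mask test and are skipped after a single comparison pass, which is where the expected scaling with the output/input ratio comes from.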
Prototype implementation
There is currently a prototype/initial implementation of this algorithm that yields the following performance improvements.
The algorithm described above results in a significant speedup of the function. The speedup grows as the output/input size ratio shrinks. The implementation appears to be about 10% faster even on small data sets where the output/input ratio is 1 (see the 1x32 -> 1x32 benchmark below, which effectively uses TopK for sorting). For bigger data sets we observe even bigger speedups:
- For 2x1024 -> 2x40 the speedup is 9.59x
- For 1x256128 -> 1x40 the speedup is 6.12x
Benchmark numbers (each benchmark was run 10 times):
2x1024 -> 2x40
Using the new algorithm with vector registers
BM_custom_call_topk_tuple_16/process_time/real_time_mean 0.075 ms 0.103 ms 10 items_per_second=13.2933k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 0.075 ms 0.102 ms 10 items_per_second=13.3588k/s
Current scalar implementation
BM_custom_call_topk_tuple_16/process_time/real_time_mean 0.719 ms 1.32 ms 10 items_per_second=1.39137k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 0.718 ms 1.32 ms 10 items_per_second=1.39307k/s
1x32 -> 1x32
Using the new algorithm with vector registers
BM_custom_call_topk_tuple_16/process_time/real_time_mean 0.042 ms 0.046 ms 10 items_per_second=23.9746k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 0.042 ms 0.046 ms 10 items_per_second=23.9371k/s
Current scalar implementation
BM_custom_call_topk_tuple_16/process_time/real_time_mean 0.051 ms 0.055 ms 10 items_per_second=19.5375k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 0.051 ms 0.055 ms 10 items_per_second=19.6053k/s
1x256128 -> 1x40
Using the new algorithm with vector registers
BM_custom_call_topk_tuple_16/process_time/real_time_mean 3.79 ms 3.67 ms 10 items_per_second=264.076/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 3.78 ms 3.67 ms 10 items_per_second=264.459/s
Current scalar implementation
BM_custom_call_topk_tuple_16/process_time/real_time_mean 23.3 ms 23.2 ms 10 items_per_second=42.9361/s
BM_custom_call_topk_tuple_16/process_time/real_time_median 23.2 ms 23.1 ms 10 items_per_second=43.0286/s
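For readers who want to recompute the ratios, the short Python sketch below divides the reported real-time means of the current scalar implementation by those of the new one. The times are copied from the benchmark output above; the dictionary layout and variable names are only illustrative. Because the printed times are rounded, the computed ratios may differ somewhat from the figures quoted in the text.

```python
# (new, current) real-time means in ms, copied from the benchmark output
reported_ms = {
    "2x1024 -> 2x40": (0.075, 0.719),
    "1x32 -> 1x32": (0.042, 0.051),
    "1x256128 -> 1x40": (3.79, 23.3),
}

# Speedup = time of the current implementation / time of the new one.
speedups = {
    shape: current / new for shape, (new, current) in reported_ms.items()
}

for shape, speedup in speedups.items():
    print(f"{shape}: {speedup:.2f}x")
```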
Implementation notes
The implementation takes a pure codegen approach. Some bugs and limitations in the base libraries and dialects were discovered in the process and required workarounds. Some of them are:
- Inability to break out of an scf::For. This required adding extra iter_args so that the loop continues to the end without modifying the returned values. This is used in the initialization of the smallest element added and in the loop that calculates the insertion position of the new element in the output.
- No support for negative steps in scf::For. In the loop that shifts elements before insertion, extra iter_args carry the element to be shifted into the next position, which achieves the same result as a negative-step loop and removes the need for temporary variable assignments.
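The two workarounds can be pictured with the following Python sketch. It only mirrors the control flow; the helper names `find_insert_pos` and `shift_right` are hypothetical, and plain Python variables stand in for the loop-carried iter_args.

```python
def find_insert_pos(out, count, x):
    """Find the first index i < count with out[i] < x, without `break`.

    The flag and position are carried through every iteration, as
    iter_args would be; once the position is found, the remaining
    iterations pass the carried values through unchanged.
    """
    not_found, pos = True, 0
    for i in range(count):
        if not_found and out[i] < x:
            not_found, pos = False, i
    return not_found, pos


def shift_right(out, start, end):
    """Shift out[start:end] one slot to the right with an ascending loop.

    The element to be written next is carried between iterations, so no
    descending (negative-step) loop is needed. Requires end < len(out).
    """
    carried = out[start]
    for i in range(start, end):
        nxt = out[i + 1]      # save the value about to be overwritten
        out[i + 1] = carried  # write the previous element one slot right
        carried = nxt
    return carried  # the value that was stored at out[end] before the shift
```

For example, with `out = [9, 7, 5, 3, 0]` holding four valid elements, inserting 6 means calling `find_insert_pos(out, 4, 6)`, then `shift_right(out, 2, 4)`, then writing 6 at index 2.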
Here is the code before and after the transformation:
Before
func.func @custom_call_topk_tuple_f10_dispatch_0_topk_1x32xf32() attributes {translation_info = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x32xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c128) : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32xf32>> -> tensor<1x32xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>> -> tensor<1x32xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>> -> tensor<1x32xi32>
%6:2 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %4, %arg2 = %5) -> (tensor<1x32xf32>, tensor<1x32xi32>) {
%extracted_slice = tensor.extract_slice %3[0, %arg0] [1, 16] [1, 1] : tensor<1x32xf32> to tensor<1x16xf32>
%extracted_slice_0 = tensor.extract_slice %arg1[0, %arg0] [1, 32] [1, 1] : tensor<1x32xf32> to tensor<1x32xf32>
%extracted_slice_1 = tensor.extract_slice %arg2[0, %arg0] [1, 32] [1, 1] : tensor<1x32xi32> to tensor<1x32xi32>
%7:2 = iree_linalg_ext.topk {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 16]]>} dimension(1) ins(%extracted_slice : tensor<1x16xf32>) outs(%extracted_slice_0, %extracted_slice_1 : tensor<1x32xf32>, tensor<1x32xi32>) {
^bb0(%arg3: f32, %arg4: f32):
%8 = arith.cmpf ogt, %arg3, %arg4 : f32
iree_linalg_ext.yield %8 : i1
} -> tensor<1x32xf32>, tensor<1x32xi32>
%inserted_slice = tensor.insert_slice %7#0 into %arg1[0, %arg0] [1, 32] [1, 1] : tensor<1x32xf32> into tensor<1x32xf32>
%inserted_slice_2 = tensor.insert_slice %7#1 into %arg2[0, %arg0] [1, 32] [1, 1] : tensor<1x32xi32> into tensor<1x32xi32>
scf.yield %inserted_slice, %inserted_slice_2 : tensor<1x32xf32>, tensor<1x32xi32>
}
flow.dispatch.tensor.store %6#0, %1, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : tensor<1x32xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x32xf32>>
flow.dispatch.tensor.store %6#1, %2, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : tensor<1x32xi32> -> !flow.dispatch.tensor<readwrite:tensor<1x32xi32>>
return
}
Codegen with comments, using the algorithm proposed above
func.func @custom_call_topk_tuple_f10_dispatch_0_topk_1x32xf32() attributes {translation_info = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x32xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c128) : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32xf32>> -> tensor<1x32xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>> -> tensor<1x32xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>> -> tensor<1x32xi32>
// Most of the scf structs (for and if) have the same iter_args/returns output0,
// output1, smallestElem and elemsAdded
// Note: the top loop has an extra arg just before the smallestElem of type i1 denoting if
// the smallestElem has been initialized.
// Add extra iter_args for initSmallest, smallestElem and addedElems
%true = arith.constant true
%cst = arith.constant 0.000000e+00 : f32
%c0_0 = arith.constant 0 : index
%6:5 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %4, %arg2 = %5, %arg3 = %true, %arg4 = %cst, %arg5 = %c0_0) -> (tensor<1x32xf32>, tensor<1x32xi32>, i1, f32, index) {
%extracted_slice = tensor.extract_slice %3[0, %arg0] [1, 16] [1, 1] : tensor<1x32xf32> to tensor<1x16xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [1, 32] [1, 1] : tensor<1x32xf32> to tensor<1x32xf32>
%extracted_slice_2 = tensor.extract_slice %arg2[0, %arg0] [1, 32] [1, 1] : tensor<1x32xi32> to tensor<1x32xi32>
// Initialize the smallest element to the first element in the input, if not initialized.
%7:4 = scf.if %arg3 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
%c0_8 = arith.constant 0 : index
%extracted = tensor.extract %extracted_slice[%c0_8, %c0_8] : tensor<1x16xf32>
%c0_9 = arith.constant 0 : index
scf.yield %arg1, %arg2, %extracted, %c0_9 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
scf.yield %arg1, %arg2, %arg4, %arg5 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
%cst_3 = arith.constant 0xFF800000 : f32
%c0_4 = arith.constant 0 : index
// Read the input elements into a vector register.
%8 = vector.transfer_read %extracted_slice[%c0_4, %c0_4], %cst_3 {in_bounds = [true, true]} : tensor<1x16xf32>, vector<1x16xf32>
// Broadcast smallestElem in another vector register.
%9 = vector.broadcast %7#2 : f32 to vector<1x16xf32>
// Check to see if there are elements in the input that are
// greater than the smallest element added.
%10 = arith.cmpf ogt, %8, %9 : vector<1x16xf32>
%c32_5 = arith.constant 32 : index
// Check to see if the output is not yet fully filled.
%11 = arith.cmpi slt, %7#3, %c32_5 : index
// If it is not full, set the mask of the vector register to all 1s,
// so the incoming elements will still be added to the output
// until it is fully filled.
%12 = scf.if %11 -> (vector<1x16xi1>) {
%true_8 = arith.constant true
%15 = vector.broadcast %true_8 : i1 to vector<1x16xi1>
scf.yield %15 : vector<1x16xi1>
} else {
scf.yield %10 : vector<1x16xi1>
}
%false = arith.constant false
// Check to see if any input elements need to be added to the output.
// Reduce over the mask register with or to see if there are elements that need
// to be added.
%13 = vector.multi_reduction <or>, %12, %false [0, 1] : vector<1x16xi1> to i1
// Insert elements if there are any needing insertion.
%14:4 = scf.if %13 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
%c0_8 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Iterate over the mask elements.
%15:4 = scf.for %arg6 = %c0_8 to %c16 step %c1 iter_args(%arg7 = %7#0, %arg8 = %7#1, %arg9 = %7#2, %arg10 = %7#3) -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
// Get the mask element for each element in the mask register.
// Need a shape cast because the vector.extractelement operates on vectors
// (arrays) only.
// The vector::extract currently doesn't support unitDim shapes.
%16 = vector.shape_cast %12 : vector<1x16xi1> to vector<16xi1>
%17 = vector.extractelement %16[%arg6 : index] : vector<16xi1>
// If the mask is true, the element needs insertion.
%18:4 = scf.if %17 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
// Extract the element that needs to be inserted in the output.
%19 = vector.shape_cast %8 : vector<1x16xf32> to vector<16xf32>
%20 = vector.extractelement %19[%arg6 : index] : vector<16xf32>
// Get the index of the input element by adding the induction vars
// of the two loops - the one iterating over input with step 16 and
// the one iterating over the elements in the vector.
%21 = arith.addi %arg6, %arg0 : index
// A flag to denote if the insertion index was found.
%true_9 = arith.constant true
%c0_10 = arith.constant 0 : index
%c0_11 = arith.constant 0 : index
%c1_12 = arith.constant 1 : index
// Loop to find the insertion index for the element that is being added.
%22:2 = scf.for %arg11 = %c0_11 to %arg10 step %c1_12 iter_args(%arg12 = %true_9, %arg13 = %c0_10) -> (i1, index) {
%true_13 = arith.constant true
%26 = arith.cmpi eq, %arg12, %true_13 : i1
// If index not found...
%27:2 = scf.if %26 -> (i1, index) {
%c0_14 = arith.constant 0 : index
// ... extract the element from the output and check to see if this
// is the insertion index of the new element.
%extracted = tensor.extract %arg7[%c0_14, %arg11] : tensor<1x32xf32>
%28 = arith.cmpf olt, %extracted, %20 : f32
%true_15 = arith.constant true
%29 = arith.cmpi eq, %28, %true_15 : i1
%30:2 = scf.if %29 -> (i1, index) {
%c0_16 = arith.constant 0 : index
%31 = arith.addi %arg11, %c0_16 : index
%false_17 = arith.constant false
scf.yield %false_17, %31 : i1, index
} else {
// If the index was found, just iterate to the end with the for loop.
// Note: there is comment in the code for this to be removed when
// scf::For supports break.
scf.yield %arg12, %arg13 : i1, index
}
scf.yield %30#0, %30#1 : i1, index
} else {
// If no insertion needed, just exit.
scf.yield %arg12, %arg13 : i1, index
}
scf.yield %27#0, %27#1 : i1, index
}
%23 = arith.cmpi eq, %arg10, %c32_5 : index
%24 = arith.andi %23, %22#0 : i1
// Check to see if the element doesn't need to be added.
%25:4 = scf.if %24 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
scf.yield %arg7, %arg8, %arg9, %arg10 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
// Check to see if the element is appended to the end of the output
// (output not fully filled yet - no insertion index found).
%26:4 = scf.if %22#0 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
// Insert the elem and its index at the end
// and increment the number of added elements.
%c0_13 = arith.constant 0 : index
%inserted = tensor.insert %20 into %arg7[%c0_13, %arg10] : tensor<1x32xf32>
%27 = arith.index_cast %21 : index to i32
%inserted_14 = tensor.insert %27 into %arg8[%c0_13, %arg10] : tensor<1x32xi32>
%c1_15 = arith.constant 1 : index
%28 = arith.addi %arg10, %c1_15 : index
scf.yield %inserted, %inserted_14, %arg9, %28 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
// The element needs to be inserted
%27 = arith.cmpi eq, %arg10, %c32_5 : index
%28 = scf.if %27 -> (index) {
// If it is inserted at the end, just add it there, pushing
// the current last element out.
// Setting the insertion index and num added elems in a way,
// that the shifting loop below will not execute.
%c1_19 = arith.constant 1 : index
%32 = arith.subi %arg10, %c1_19 : index
scf.yield %32 : index
} else {
scf.yield %arg10 : index
}
%c1_13 = arith.constant 1 : index
%c0_14 = arith.constant 0 : index
// set the values of the elements that need to be shifted right as
// iter_args.
%extracted = tensor.extract %arg7[%c0_14, %22#1] : tensor<1x32xf32>
%extracted_15 = tensor.extract %arg8[%c0_14, %22#1] : tensor<1x32xi32>
%c1_16 = arith.constant 1 : index
// Shift the elements from insertionIndex to the end to the right.
%29:4 = scf.for %arg11 = %22#1 to %28 step %c1_16 iter_args(%arg12 = %arg7, %arg13 = %arg8, %arg14 = %extracted, %arg15 = %extracted_15) -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, i32) {
%c1_19 = arith.constant 1 : index
%32 = arith.addi %arg11, %c1_19 : index
%c0_20 = arith.constant 0 : index
// Get the next elements and set it as iter args before...
%extracted_21 = tensor.extract %arg12[%c0_20, %32] : tensor<1x32xf32>
%extracted_22 = tensor.extract %arg13[%c0_20, %32] : tensor<1x32xi32>
// ... setting the previous element values (in the iter_args)...
%inserted_23 = tensor.insert %arg14 into %arg12[%c0_20, %32] : tensor<1x32xf32>
%inserted_24 = tensor.insert %arg15 into %arg13[%c0_20, %32] : tensor<1x32xi32>
// Update the iter_args.
scf.yield %inserted_23, %inserted_24, %extracted_21, %extracted_22 : tensor<1x32xf32>, tensor<1x32xi32>, f32, i32
}
// Get the current last element and update the iter_arg of the outermost loop.
%extracted_17 = tensor.extract %29#0[%c0_14, %28] : tensor<1x32xf32>
%30 = arith.addi %28, %c1_13 : index
%inserted = tensor.insert %20 into %29#0[%c0_14, %22#1] : tensor<1x32xf32>
%31 = arith.index_cast %21 : index to i32
%inserted_18 = tensor.insert %31 into %29#1[%c0_14, %22#1] : tensor<1x32xi32>
scf.yield %inserted, %inserted_18, %extracted_17, %30 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
scf.yield %26#0, %26#1, %26#2, %26#3 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
scf.yield %25#0, %25#1, %25#2, %25#3 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
scf.yield %arg7, %arg8, %arg9, %arg10 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
scf.yield %18#0, %18#1, %18#2, %18#3 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
scf.yield %15#0, %15#1, %15#2, %15#3 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
scf.yield %7#0, %7#1, %7#2, %7#3 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
%inserted_slice = tensor.insert_slice %extracted_slice_1 into %arg1[0, %arg0] [1, 32] [1, 1] : tensor<1x32xf32> into tensor<1x32xf32>
%inserted_slice_6 = tensor.insert_slice %extracted_slice_2 into %arg2[0, %arg0] [1, 32] [1, 1] : tensor<1x32xi32> into tensor<1x32xi32>
%false_7 = arith.constant false
scf.yield %14#0, %14#1, %false_7, %14#2, %14#3 : tensor<1x32xf32>, tensor<1x32xi32>, i1, f32, index
}
flow.dispatch.tensor.store %6#0, %1, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : tensor<1x32xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x32xf32>>
flow.dispatch.tensor.store %6#1, %2, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : tensor<1x32xi32> -> !flow.dispatch.tensor<readwrite:tensor<1x32xi32>>
return
}
Feature PR: https://github.com/iree-org/iree/pull/17045
Thanks @LLITCHEV for the RFC. Some comments, though a lot of the details here are a bit over my head.
The lowering looks very much like "hand-crafted IR", i.e. it is not a sequence of composable transformations that gets us to the final state. It is for such canned, hand-crafted sequences that the ukernel path was plumbed through. There is one advantage here though: in terms of deployment you don't need to worry about a specific hardware ISA and writing ukernels for each. It is worth doing that for something like GEMM, where getting the right hardware ISA is important.
So I suggest an intermediate solution (leveraging what you already have). For the most part we can piggy back along the ukernel path, i.e.
- Recognize the op you want to lower to a ukernel, here the top-k op, and lower it to the iree_codegen.ukernel.generic operation
- Follow the same path as all other microkernels take until it gets lowered to a function call (i.e. bufferization + lower ukernels to func call)
At this stage you will have an empty function which represents the top-k kernel. You can provide the implementation of that function (i.e. fill in the body of the function) as you want it and then just inline the function. Effectively instead of relying on linking at LLVM level to provide the ukernel function definition, you are just building the function body in MLIR and then lowering it to LLVM.
Thoughts? cc @jpienaar as well?
@MaheshRavishankar Thank you so much for the thoughtful comments! I really appreciate it! I have discussed these with Diego, and the following are some of the highlights of those conversations:
- First, I implemented this initially as a standalone function in C and got very similar perf results. I thought the easiest way forward would be to just incorporate the C function as an RTL and call it.
- It was suggested that implementing it as an scf/vector dialect consumer would be more in line with the overall value of IREE: a cross-platform implementation that would run on any backend that IREE supports. I do agree with that, since this code should run on GPUs and other hardware we support as well, even HW that doesn't have an LLVM backend, like things that are targeted by the scf and vector dialects only (not LLVM). (I have not attempted it yet, but technically it should just work, although I won't be surprised if there are things that need fixing. There were several issues in different dialects that needed fixing, including missing operation conversions for lower dialects, before this codegen actually worked. I suspect there might be a need to fix some things on some of the other backends.)
- This implementation also made sense since all the other "canned" operations are done in a similar way in the same file, GenericVectorization.cpp. The "canned" functions that are generated the same way are PadOp, PackOp, UnpackOp, and LinalgOp that are vectorized.
- I have also run into some limitations and correctness issues in the current implementations of the TopK function on the different platforms, which an scf/vector solution like this one would hopefully make moot. Some examples are:
  a. The current CPU implementation has no support for int32 types (it resulted in errors when I tried to compile). I had to disable some tests; see https://github.com/iree-org/iree/blob/50a134a9c38df30b1c6cacb59416f7f671437b3f/tests/e2e/linalg_ext_ops/top-k-vector.mlir. It also doesn't comply with the layout_config, so if another layer of tiling is introduced for the input/output of the operation, the result will be incorrect.
  b. The tests are already disabled on Vulkan with a note that the codegen is incorrect.
  c. On CUDA, I had to disable my tests because they were producing wrong results. In particular, in https://github.com/iree-org/iree/blob/50a134a9c38df30b1c6cacb59416f7f671437b3f/tests/e2e/linalg_ext_ops/top-k-vector.mlir, the test vector_call_topk_1x256 fails with duplicated elements in the output. (Issue: https://github.com/iree-org/iree/issues/17275)
```
/work/runtime/src/iree/modules/check/module.cc:372: Failure
Failed
Expected equality of these values. Contents does not match.
lhs:
1x40xi32=[135 135 169 169 247 247 233 233 3 3 7 7 11 11 15 15 19 19 23 23 27 27 31 31 35 35 39 39 43 43 47 47 51 51 55 55 59 59 63 63]
rhs:
1x40xi32=[135 169 247 233 3 7 11 15 19 23 27 31 35 39 43 47 51 55 59 63 67 71 75 79 83 87 91 95 99 103 107 111 115 119 123 127 131 139 143 147]
```
IMHO, having the codegen done using scf/vector and other MLIR dialects is beneficial, because it serves one of the IREE goals: hardware heterogeneity.
Any further thoughts are really appreciated.
> - This implementation also made sense since all the other "canned" operations are done in a similar way in the same file, GenericVectorization.cpp. The "canned" functions that are generated the same way are PadOp, PackOp, UnpackOp, and LinalgOp that are vectorized.
I might be mistaken here, but the implementation here is not on par with the other ops you mention. Those do not generate a loop, etc. like you have here. The approach with those ops is generally
- Use tiling to generate the loops
- Lower to a "straight sequence" of operations for vectorization.
For example, after the code you generate here you can't "unroll the vectors" (or, on GPU, you can't distribute the vectors). From here it pretty much just goes down to LLVM (for the most part). So that's the phase-ordering issue I see here, and we can avoid it by deferring to a point where the code is generated much closer to the LLVM lowering.
> - I have also run into some limitations and correctness issues in the current implementations of the TopK function on the different platforms, which an scf/vector solution like this one would hopefully make moot.
Sorry, I am missing whether the implementation in your RFC is producing wrong results or IREE is producing wrong results at HEAD.
@MaheshRavishankar Thanks for the comments! First, this turned out longer than I expected. Sorry! We can discuss it more in Mai-Tai...
- Sorry, my understanding (and Diego's, from what I understood) is that the pass generates "canned" operations: operations that generally don't fuse and are standalone pieces of code. At least it looked that way to me. For TopK one can't really fuse anything because there is a "sort-of" loop-carried dependency: the order. I'm not sure whether the same holds for (some of) the rest of the operations; thinking about it a bit, pack/unpack and even pad have the same property, I think. The order of things matters, and the code is probably not really amenable to fusing and a lot of other optimizations. It is possible that I'm missing something, though. For the TopK implementation in this proposal, distribution is done on the rows, because they are independent of each other. So each row of the input can be (and is) executed on a different core/warp. For the existing implementation (at least on CPU) there is no distribution at all. I tried adding some (in the same manner as described above) and the execution time doubled; it looked like the constant sorting was blowing up the caches. Having the codegen done in the SCF and vector dialects also supports codegen for targets that have no LLVM backend. It will just work (hopefully, after all the necessary fixes down the pipeline, if some are required). And of course it works for all the targets that do have an LLVM backend.
- As far as the failures... I did try fully disabling my code back when I added my new tests (that fail on CUDA). I just retried it to make sure. The failure still happens: https://github.com/iree-org/iree/actions/runs/8942205116/job/24564575269?pr=17045 My code is completely disabled and the failure is on HEAD from yesterday. And here is the issue I opened for that: https://github.com/iree-org/iree/issues/17275
- To summarize, IMHO, having the implementation in MLIR builders instead of a ukernel has the following benefits:
  a. It provides support for non-LLVM backends (I don't know how important this is for IREE?). It also provides out-of-the-box support for new backends (devices). The different backend compilers (in LLVM and elsewhere) are still free to optimize this code, splitting the lowered higher-level scf loops the way they see fit. For this codegen in a ukernel (LLVM), I think I would need intrinsics for vector registers, and last time I checked (some time ago) they were not yet fully unified between backends (only a very small subset was). Some of them are still experimental (like cmpf, for example).
  b. Non-C numerical types (sub-byte types). It seems to me that these can be handled better in MLIR builder codegen. I don't really see how these would look in ukernels. I think we do care about these.
  c. It provides a unified implementation across the different backends. There is one implementation that works on all the backends: consistent, correct, and fast. IMHO, this also reduces maintenance, and the code will behave the same on all the backends. This assumes that all the backends are well optimized to compile fast MLIR code (which, I think, they should be?).
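The row-wise distribution mentioned in the first point above relies only on the rows being independent. The Python sketch below illustrates that property; it says nothing about how IREE actually distributes work. `topk_row`, `topk_rows`, and the `workers` parameter are hypothetical names, a thread pool stands in for cores/warps, and `heapq.nlargest` is a stand-in for the per-row kernel.

```python
from concurrent.futures import ThreadPoolExecutor
import heapq


def topk_row(row, k):
    """Stand-in per-row kernel: k largest values and their indices."""
    idx = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
    return [row[i] for i in idx], idx


def topk_rows(matrix, k, workers=4):
    """Run the per-row kernel on every row; rows share no state."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda row: topk_row(row, k), matrix))
```

Each call to `topk_row` reads one row and produces its own output, so no synchronization between rows is needed beyond collecting the results.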
As I said before, I hadn't realized that this pass is only for "canned" operations that don't generate for loops and ifs. If that is the case, I can certainly move the codegen for the TopK canned operation into a different pass.
Thanks!
We can discuss more next week in Mai-Tai... I am not disagreeing that it's useful to have the IR generated from MLIR, but I am saying this is essentially using a builder API to write a microkernel (which is totally OK). I'm just saying you do that as an "implementation of a function": instead of the implementation coming from an LLVM bitcode file or library, you are building the MLIR implementation of it. Happy to take that, but this isn't a general vectorization recipe. I don't follow your comments on pack/pad etc. Those do not vectorize to scf + vector ops. They generate vector ops as straight-line code. We can talk more on Mai-Tai.