[Codegen][TopK] Faster TopK implementation
This is a hybrid (vector and scalar) implementation of TopK. The load and compare of elements is done using vectors, and the sorting/writing of the result is done using scalar registers. The current implementation is pure scalar.
This results in a significant speedup of the function. The speedup grows as the output/input size ratio shrinks.
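For context, this is roughly what `iree_linalg_ext.topk` computes per row - a minimal reference sketch only (the function name and the sort-based formulation are illustrative; this is not the pure-scalar implementation being replaced):

```cpp
// Reference semantics only (not the implementation being replaced): per row,
// return the k largest values in descending order together with their indices.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::pair<std::vector<float>, std::vector<int32_t>>
referenceTopK(const std::vector<float> &in, int64_t k) {
  std::vector<int32_t> idx(in.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int32_t a, int32_t b) { return in[a] > in[b]; });
  idx.resize(std::min<int64_t>(k, static_cast<int64_t>(idx.size())));
  std::vector<float> vals(idx.size());
  for (size_t i = 0; i < idx.size(); ++i) vals[i] = in[idx[i]];
  return {vals, idx};
}
```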
This implementation seems to be about 10% faster even on smaller data sets where the output/input ratio is 1 (see the benchmark below for 1x32 -> 1x32, which is essentially using TopK for sorting).
For bigger data sets we observe even bigger speedups:
- For 2x1024 -> 2x40 the speedup is 9.59x
- For 1x256128 -> 1x40 the speedup is 6.12x.
Benchmark numbers - each benchmark was run 10 times.
2x1024 -> 2x40
Vector:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     0.075 ms   0.103 ms   10   items_per_second=13.2933k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   0.075 ms   0.102 ms   10   items_per_second=13.3588k/s
Scalar:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     0.719 ms   1.32 ms    10   items_per_second=1.39137k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   0.718 ms   1.32 ms    10   items_per_second=1.39307k/s
1x32 -> 1x32
Vector:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     0.042 ms   0.046 ms   10   items_per_second=23.9746k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   0.042 ms   0.046 ms   10   items_per_second=23.9371k/s
Scalar:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     0.051 ms   0.055 ms   10   items_per_second=19.5375k/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   0.051 ms   0.055 ms   10   items_per_second=19.6053k/s
1x256128 -> 1x40
Vector:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     3.79 ms    3.67 ms    10   items_per_second=264.076/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   3.78 ms    3.67 ms    10   items_per_second=264.459/s
Scalar:
BM_custom_call_topk_tuple_16/process_time/real_time_mean     23.3 ms    23.2 ms    10   items_per_second=42.9361/s
BM_custom_call_topk_tuple_16/process_time/real_time_median   23.2 ms    23.1 ms    10   items_per_second=43.0286/s
@MaheshRavishankar Reverted the files. The flag controlling the use of TopK vectorization was originally in these files, but limitations of the current implementation required moving it to KernelDispatch. I thought the formatter handled more than just C/C++ files (the way it does in the Phoenix compiler, where it also formats the *.td files, at least for indentation and line length), so I ran the formatter over them. My bad!
Also, please add a high-level description of the IR changes you are looking to make (like the MLIR before and after). This also shows up in the lit tests, but a standalone description is easier to read.
+1. Sorry that I haven't looked into the details yet, but I wonder why there are scf ops in the vectorization? I don't expect to generate any scf ops during vectorization.
@hanhanW No worries! I know it is a lot of code... I have a PR for making createReadOrMaskedRead a utility function in MLIR - https://github.com/llvm/llvm-project/pull/89119. I'm adding LitTests at the moment. I'll add some more descriptions in the code of what it is doing, although I think there are some already.

To give some more details re: the scf::Ifs etc.: I stated in the initial comment that this is a hybrid approach. What that means is that the loading and comparing of the inputs are done in vectors, and once the decision is made that an element is to be inserted into the output, the implementation lowers to scalar code to find the insertion index and insert the element in sorted order into the output. Unfortunately, the second part can't really be done in vectors because there is an implied loop-carried dependency - the sort order. We get the good speedup because there is no need to try to insert every element. The algorithm is designed to have a bigger relative speedup as the ratio of input size to output size increases.

The algorithm maintains state represented as two variables: smallestElem - the smallest element added to the output, and addedElems - the number of elements added to the output. The first scf::If initializes smallestElem. Then a vector's worth of elements is loaded and compared against a vector broadcast of smallestElem. A reduction is made on the result of the compare, and if there are elements in the input that are bigger than the smallest element added, or the output is not fully filled (I think this was the second scf::If), we go into the part of the algorithm that inserts the element into the output.

First, the element that is bigger than the smallest element is identified and extracted along with its index. Then we find the index in the output where the element needs to be added - the second scf::For in the codegen (unfortunately, scf::For doesn't support bailing out from the middle of the loop, so I had to add a flag as an iter_arg to keep track of whether the insertion position has been found; even with that, some cycles are wasted, since we need to iterate to the end no matter what). Then the element is inserted into the output after shifting the rest of the elements (the third scf::For loop). There is a special case where the element is added because the output is not yet filled and the element is not bigger than the smallest element added - in that case no shifting is needed and the element is simply appended at the end of the output. Since scf::For doesn't support negative steps (per the docs), I was still able to remove the need for a temp var and extra assignments by using iter_args in the third (shifting) loop with the old value from the previous index (iteration).

I hope that answers your questions. If not, ask... :)
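To make the walk-through above easier to follow, here is a rough C++ sketch of the control flow the lowering produces - a sketch under assumptions only (vector width of 16, input length a multiple of 16, descending order, and all names are illustrative; the actual output is the MLIR pasted further down):

```cpp
// Hedged sketch of the hybrid lowering's control flow (illustrative only; the
// generated code is SCF/vector IR, not C++). Assumptions: vector width of 16,
// input length a multiple of 16, descending sort order, f32 values.
#include <cstdint>

constexpr int64_t kVec = 16;  // assumed vector width

void hybridTopKRow(const float *in, int64_t n, float *out, int32_t *outIdx,
                   int64_t k) {
  bool needInit = true;    // mirrors the first scf.if: smallestElem not set yet
  float smallestElem = 0;  // smallest element currently in the output
  int64_t addedElems = 0;  // number of elements written to the output

  for (int64_t base = 0; base < n; base += kVec) {  // outer scf.for over chunks
    if (needInit) {                                 // first scf.if
      smallestElem = in[base];
      needInit = false;
    }
    // Vector part: one load of 16 elements, compared against a broadcast of
    // smallestElem, followed by a reduction. The second scf.if forces the
    // insertion path while the output is not yet full.
    bool anyCandidate = (addedElems < k);
    for (int64_t j = 0; j < kVec; ++j)
      anyCandidate |= in[base + j] > smallestElem;
    if (!anyCandidate)
      continue;  // the whole chunk is skipped - this is where the speedup comes from

    // Scalar part: insert each qualifying element of the chunk in sorted order.
    for (int64_t j = 0; j < kVec; ++j) {
      float v = in[base + j];
      if (!(v > smallestElem) && addedElems >= k)
        continue;
      // Find the insertion position (second scf.for, with an iter_arg flag,
      // since scf.for cannot bail out early).
      int64_t pos = (addedElems < k) ? addedElems : k - 1;
      while (pos > 0 && v > out[pos - 1])
        --pos;
      // Shift the smaller elements down and insert (third scf.for). The special
      // case pos == addedElems (output not full, v not bigger than the smallest
      // element) needs no shifting - the element is just appended.
      int64_t last = (addedElems < k) ? addedElems : k - 1;
      for (int64_t s = last; s > pos; --s) {
        out[s] = out[s - 1];
        outIdx[s] = outIdx[s - 1];
      }
      out[pos] = v;
      outIdx[pos] = static_cast<int32_t>(base + j);
      if (addedElems < k) ++addedElems;
      smallestElem = out[addedElems - 1];  // maintain the smallestElem state
    }
  }
}
```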
It would be very helpful if you could share the IR before and after vectorization. I'd like to take a look and see if I can provide any suggestions. Maybe there are other passes you can reuse, so we don't need to make it this complicated. LinalgExt is designed to be very similar to linalg. Generating scf ops during vectorization is a big red flag to me, because we never use them in the vectorizer.
I will add the LitTests and make the PR official as well; they will provide the before and after code... I'll paste it here for a small test case with 1x32xf32 inputs and outputs. The algorithm doesn't require using vector registers at all - I'm sure we would still see a very significant speedup if we used the algorithm with scalars only. On the other hand, using vector registers for the load and compare saves about 30 cycles per 16 elements - one load and one compare of 16 elements, instead of doing it element by element. That is why the vector registers are useful, IMHO, in this case... It saves cycles...
The codegen basically replaces the topk operation with the codegen I described above...
// -----// IR Dump Before GeneralizeLinalgNamedOps (iree-global-opt-generalize-linalg-named-ops) //----- //
util.func public @custom_call_topk_test_32() -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @custom_call_topk_test_32() -> (%output0: tensor<1x32xf32>, %output1: tensor<1x32xi32>)"}} {
  %c0_i32 = arith.constant 0 : i32
  %cst = arith.constant 0xFF800000 : f32
  %cst_0 = arith.constant dense<[[1.300000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.400000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.800000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.230000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.200000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]]> : tensor<1x32xf32>
  %0 = util.optimization_barrier %cst_0 : tensor<1x32xf32>
  %1 = tensor.empty() : tensor<1x32xf32>
  %2 = tensor.empty() : tensor<1x32xi32>
  %3 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x32xf32>) -> tensor<1x32xf32>
  %4 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<1x32xi32>) -> tensor<1x32xi32>
  %5:2 = iree_linalg_ext.topk dimension(1) ins(%0 : tensor<1x32xf32>) outs(%3, %4 : tensor<1x32xf32>, tensor<1x32xi32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    %8 = arith.cmpf ogt, %arg0, %arg1 : f32
    iree_linalg_ext.yield %8 : i1
  } -> tensor<1x32xf32>, tensor<1x32xi32>
  %6 = hal.tensor.export %5#0 "output0" : tensor<1x32xf32> -> !hal.buffer_view
  %7 = hal.tensor.export %5#1 "output1" : tensor<1x32xi32> -> !hal.buffer_view
  util.return %6, %7 : !hal.buffer_view, !hal.buffer_view
}
// -----// IR Dump After GenericVectorization (iree-codegen-generic-vectorization) //----- //
func.func @custom_call_topk_test_32_dispatch_0_topk_1x32xf32() attributes {translation_info = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x32xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c128) : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32xf32>> -> tensor<1x32xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xf32>> -> tensor<1x32xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1, 32], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x32xi32>> -> tensor<1x32xi32>
%true = arith.constant true
%cst = arith.constant 0.000000e+00 : f32
%c0_0 = arith.constant 0 : index
%6:5 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %4, %arg2 = %5, %arg3 = %true, %arg4 = %cst, %arg5 = %c0_0) -> (tensor<1x32xf32>, tensor<1x32xi32>, i1, f32, index) {
%extracted_slice = tensor.extract_slice %3[0, %arg0] [1, 16] [1, 1] : tensor<1x32xf32> to tensor<1x16xf32>
%extracted_slice_1 = tensor.extract_slice %arg1[0, %arg0] [1, 32] [1, 1] : tensor<1x32xf32> to tensor<1x32xf32>
%extracted_slice_2 = tensor.extract_slice %arg2[0, %arg0] [1, 32] [1, 1] : tensor<1x32xi32> to tensor<1x32xi32>
%7:4 = scf.if %arg3 -> (tensor<1x32xf32>, tensor<1x32xi32>, f32, index) {
%c0_8 = arith.constant 0 : index
%extracted = tensor.extract %extracted_slice[%c0_8, %c0_8] : tensor<1x16xf32>
%c0_9 = arith.constant 0 : index
scf.yield %arg1, %arg2, %extracted, %c0_9 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
} else {
scf.yield %arg1, %arg2, %arg4, %arg5 : tensor<1x32xf32>, tensor<1x32xi32>, f32, index
}
%cst_3 = arith.constant 0xFF800000 : f32
%c0_4 = arith.constant 0 : index
%8 = vector.transfer_read %extracted_slice[%c0_4, %c0_4], %cst_3 {in_bounds = [true, true]} : tensor<1x16xf32>, vector<1x16xf32>
%9 = vector.broadcast %7#2 : f32 to vector<1x16xf32>
%10 = arith.cmpf ogt, %8, %9 : vector<1x16xf32>
%c32_5 = arith.constant 32 : index
%11 = arith.cmpi slt, %7#3, %c32_5 : index
%12 = scf.if %11 -> (vector<1x16xi1>) {
%true_8 = arith.constant true
%15 = vector.broadcast %true_8 : i1 to vector<1x16xi1>
scf.yield %15 : vector<1x16xi1>
} else {
scf.yield %10 : vector<1x16xi1>
}
%false = arith.constant false
%13 = vector.multi_reduction
The numbers are really impressive, really awesome work. I'm happy to learn the context and see how we land it in IREE.
The approach taken here is basically to generate a custom vector kernel, or, as we call it for some other ops (e.g., pack/unpack), apply "direct" vectorization. It's true that the code has some complexity, but it's implementing a custom sort algorithm, so a certain level of complexity is expected.
Can we have a doc/RFC to describe the algorithm? I think people (like me) can learn from it when they look at the implementation details. And if there are new people coming to this land, we can point them to the RFC.
That's a good idea. If it goes in a doc, I think it's important that the doc stays close to the code (also thinking that this could eventually be upstreamed). I think that having a detailed description of the algorithm, including some snippets of the generated code and a few execution examples, would be very helpful!
RFC issue: https://github.com/iree-org/iree/issues/17143
Wow, this got rebased really wrong. Please fix. If doing a rebase you're uncomfortable with you can use git log (or any other git visualizer) locally before pushing. I'm guessing your local main hadn't been updated in ages and you rebased on that instead of origin/main.
@benvanik Thanks! Yeah, I noticed after pushing it and fixed the diff (checked out the state before the rebase). It was a surprise that the rebase did what it did... I did update with the latest upstream (origin)/main before the rebase, and most of the affected files were not changed by me (and there were no conflicts to resolve), so it was completely unexpected for them to be modified... Again, thanks for pointing it out!!!
Made the following changes to the original PR:
- Moved the SCF lowering of TopK out of GenericVectorization into a new pass, TopkLowering.
  a. A new StringAttr was introduced to control the lowering process at the different levels. The attribute is assigned to the TopkOp, denoting that TopkLowering is needed. The attribute is added in the same way (and at the same place in the code) as for the mmt4 operation (CUtils.h/cpp); see the sketch after this list.
- Moved the tests added for the original implementation to new test files.
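For illustration, the marking could look roughly like this - a sketch using the generic MLIR attribute API, where the attribute name "lowering_strategy", its value, and the helper names are assumptions, not the identifiers actually used in the PR:

```cpp
// Sketch only: the attribute name, its value, and the helper names are
// assumptions for illustration; they are not the actual identifiers in CUtils.h/cpp.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// In KernelDispatch (or wherever the strategy is decided), tag the TopkOp
// that should take the new lowering path.
static void markTopkForScfLowering(mlir::Operation *topkOp) {
  mlir::OpBuilder b(topkOp->getContext());
  topkOp->setAttr("lowering_strategy", b.getStringAttr("topk_scf_lowering"));
}

// In the TopkLowering pass, only rewrite ops that carry the marker.
static bool shouldLowerTopk(mlir::Operation *op) {
  auto attr = op->getAttrOfType<mlir::StringAttr>("lowering_strategy");
  return attr && attr.getValue() == "topk_scf_lowering";
}
```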
I also did an alternative implementation as a microkernel, at a later point in the translation pipeline. The microkernel implementation exhibited a few issues that made it more sensible to keep the implementation as a simple SCF lowering in a separate pass:
- The microkernel implementation was established as a separate function, and the compiler implemented a frame for it. This turns out to be quite expensive for TopK calls on small input sizes. For example, a TopK invoked on a register-sized number of elements (16 32-bit values) exhibited about a 17.4% regression compared to the SCF lowering solution, the reason being that the frame setup and teardown code was of significant size compared to the useful operations. Examining the code, there was really no way to inline the microkernel (maybe I'm missing something).
- Since the microkernel is treated as a completely separate function, the loop unrolling optimizer (and the loop optimizers in general - hoisting, etc.) does not run across the boundary between the microkernel and its caller. Thus, for simple cases (a single-iteration nested loop) the loop was not removed, and in other simple cases it was not unrolled. This resulted in extra instructions for the loop prolog and epilog, or required extra checks at runtime and dispatching to different versions of the kernel - one without a loop and one with a loop. Either way, that resulted in non-optimal code with extra instructions at runtime.
- The microkernel implementation relied on some intrinsics that are not HW-platform independent. The SCF lowering generates code that, at a higher level, doesn't rely on HW intrinsics and is lowerable to any backend (LLVM-based or not) that supports SCF.
Sorry, haven't had time to review this, but will do so this week.
@MaheshRavishankar Whenever you get a chance. Thank you very much!!
@MaheshRavishankar Thanks!!! I'll reach out to discuss possible changes to make it more readable. Thanks again!
@hanhanW I think I addressed the requests you made some time ago. The flag is changed now; it was requested to be there by another reviewer.