
Performance optimization for indexing a SharedArray

Open: jim19930609 opened this issue 2 years ago • 4 comments

Take a randomly accessible object mem as an example. Suppose we index mem with a loop index i plus some constant:

mem[i + const] -> mem_ptr + (i + const) * sizeof(dtype) 

The above indexing statement takes 2 x Add + 1 x Mul instructions.

Current Implementation

Especially in code that uses SharedArrays, this indexing pattern can occur many times within a single loop body, so the loop looks like:

mem[i + const0] -> mem_ptr + (i + const0) * sizeof(dtype)
...
mem[i + constN] -> mem_ptr + (i + constN) * sizeof(dtype)

This takes roughly 2N x Add + N x Mul instructions in total.
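To make the instruction count concrete, here is a minimal Python sketch of the naive scheme. It assumes the same simplified cost model as the text (each `+` is one Add and each `*` is one Mul, ignoring loads and register moves); `naive_addresses` and `OpCount` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class OpCount:
    """Run-time arithmetic instructions under the simplified cost model."""
    adds: int = 0
    muls: int = 0


def naive_addresses(mem_ptr: int, i: int, consts: List[int],
                    elem_size: int) -> Tuple[List[int], OpCount]:
    """Compute mem_ptr + (i + c) * elem_size from scratch for every constant c."""
    ops = OpCount()
    addrs = []
    for c in consts:
        index = i + c                         # 1 x Add
        byte_offset = index * elem_size       # 1 x Mul
        addrs.append(mem_ptr + byte_offset)   # 1 x Add
        ops.adds += 2
        ops.muls += 1
    return addrs, ops


# Four accesses mem[i + 0] .. mem[i + 3] with 4-byte elements.
addrs, ops = naive_addresses(mem_ptr=0x1000, i=5, consts=[0, 1, 2, 3], elem_size=4)
print(addrs, ops)  # [4116, 4120, 4124, 4128] OpCount(adds=8, muls=4)
```

Every access repeats both the index addition and the scaling multiplication, which is where the 2N x Add + N x Mul comes from.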

Optimized Implementation

We can optimize this pattern by computing mem_ptr + i * sizeof(dtype) once, then constant-folding constN * sizeof(dtype) in each statement:

base_ptr = mem_ptr + i*sizeof(dtype)
mem[i + const0] -> base_ptr + const0' 
...
mem[i + constN] -> base_ptr + constN'

where constN' is the constant-folded result of constN * sizeof(dtype).

The optimized code takes 1 x Add + 1 x Mul + N x Add, which saves roughly N x Add + N x Mul for large N.
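The following sketch, under the same simplified cost model, compares the naive scheme with the hoisted-base scheme and checks that both yield identical addresses. The function names `naive` and `hoisted` are hypothetical; the multiplication `c * elem_size` stands in for compile-time constant folding and is therefore not counted as a run-time instruction.

```python
from typing import List, Tuple


def naive(mem_ptr: int, i: int, consts: List[int], elem_size: int) -> Tuple[List[int], int, int]:
    """mem_ptr + (i + c) * elem_size per access: 2 x Add + 1 x Mul each."""
    addrs = [mem_ptr + (i + c) * elem_size for c in consts]
    return addrs, 2 * len(consts), len(consts)  # (addresses, adds, muls)


def hoisted(mem_ptr: int, i: int, consts: List[int], elem_size: int) -> Tuple[List[int], int, int]:
    """base_ptr = mem_ptr + i * elem_size once, then base_ptr + c' per access."""
    base_ptr = mem_ptr + i * elem_size          # 1 x Mul + 1 x Add, shared by all accesses
    folded = [c * elem_size for c in consts]    # c' = c * elem_size: folded at compile time, not counted
    addrs = [base_ptr + c_prime for c_prime in folded]  # 1 x Add per access
    return addrs, 1 + len(consts), 1


consts = [10, 11, 12, 13]
n_addrs, n_adds, n_muls = naive(0x1000, 7, consts, 4)
h_addrs, h_adds, h_muls = hoisted(0x1000, 7, consts, 4)
assert n_addrs == h_addrs  # same addresses either way
print(n_adds, n_muls, "->", h_adds, h_muls)  # 8 4 -> 5 1
```

With N accesses the saving is N - 1 Adds and N - 1 Muls, which approaches the N x Add + N x Mul quoted above as N grows.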

Real world example

The above pattern usually appears inside statically unrolled for-loops:

import taichi as ti

ti.init(arch=ti.cuda)  # assumption: CUDA backend, where SharedArray and block.sync are supported

t_group = 4
b_group = 8
block_dim = 192
Tmax = 768
ti_matf = ti.types.matrix(b_group, t_group, dtype=float)
x = ti.field(ti.f32, shape=1)

@ti.kernel
def test():
    T = 40
    ti.loop_config(block_dim=block_dim)
    for t_block in ti.ndrange(T // t_group):
        t = t_block * t_group
        w_pad = ti.simt.block.SharedArray((Tmax,), ti.f32)

        ti.simt.block.sync()
        for u in range(0, t + 1):
            # bi and i are unrolled at compile time, so each access is w_pad[u + constant]
            for bi in ti.static(range(b_group)):
                for i in ti.static(range(t_group)):
                    x[0] = w_pad[(u + bi + i)]
        ti.simt.block.sync()

test()

You'll notice the repeated 2 x Add + 1 x Mul pattern in the generated code, as shown in the screenshot:

[Screenshot from 2022-07-19 18-31-55]
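Since `bi` and `i` are `ti.static` loop variables, every access in the unrolled body has the form `w_pad[u + c]`, where `u` is the only run-time variable and `c = bi + i` is a compile-time constant. The plain-Python sketch below (it does not use Taichi, and `unrolled_offsets` is a hypothetical helper) enumerates those constants for `b_group = 8` and `t_group = 4`.

```python
from typing import List

t_group = 4  # same values as the kernel above
b_group = 8


def unrolled_offsets(b_group: int, t_group: int) -> List[int]:
    """Constant part c = bi + i of each w_pad[u + bi + i] access after ti.static unrolling."""
    return [bi + i for bi in range(b_group) for i in range(t_group)]


offsets = unrolled_offsets(b_group, t_group)
print(len(offsets))          # 32
print(sorted(set(offsets)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

So each iteration of `u` contains 32 `w_pad` accesses whose constant parts span only 11 distinct values, all sharing the variable part `u`. That is the situation where hoisting one base pointer (`w_pad + u * 4`) can pay off.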

Posted by jim19930609, Jul 19 '22 10:07

Implementation Plan

There are two ways to produce the optimized PTX code:

  1. Write a CHI IR pass to transform PtrOffsetStmt; an example is shown below. This will be the ultimate fallback solution.
  2. Write an LLVM pass to optimize the offset calculation across a group of GEPs, which is a more generic optimization. There is already a pass doing a similar job: https://llvm.org/doxygen/SeparateConstOffsetFromGEP_8cpp_source.html.

I will first check whether the SeparateConstOffsetFromGEP pass can be safely enabled without side effects; otherwise, I will fall back to the CHI IR pass for PtrOffsetStmt.
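As a rough illustration of the idea behind SeparateConstOffsetFromGEP (a toy model, not LLVM's actual algorithm or API), the sketch below splits each index expression into its variable terms and a summed constant, then groups indices that share the same variable part so that each group could reuse one base pointer. The tuple encoding of expressions and the helper names `split_const` and `group_by_base` are assumptions made for this example.

```python
from collections import defaultdict
from typing import Dict, List, Tuple, Union

# Toy index expression: an int constant, a variable name, or ("add", lhs, rhs).
Expr = Union[int, str, tuple]


def split_const(expr: Expr) -> Tuple[List[str], int]:
    """Split an add-tree into (variable terms, summed constant offset)."""
    if isinstance(expr, int):
        return [], expr
    if isinstance(expr, str):
        return [expr], 0
    op, lhs, rhs = expr
    if op != "add":
        raise ValueError("this toy model only handles additions")
    lhs_vars, lhs_const = split_const(lhs)
    rhs_vars, rhs_const = split_const(rhs)
    return lhs_vars + rhs_vars, lhs_const + rhs_const


def group_by_base(indices: List[Expr]) -> Dict[Tuple[str, ...], List[int]]:
    """Group indices by their variable part; each group can share one base pointer."""
    groups: Dict[Tuple[str, ...], List[int]] = defaultdict(list)
    for expr in indices:
        variables, const = split_const(expr)
        groups[tuple(sorted(variables))].append(const)
    return dict(groups)


indices = [("add", "i", 10), ("add", "i", 11), ("add", ("add", "i", 2), 10), ("add", 13, "i")]
print(group_by_base(indices))  # {('i',): [10, 11, 12, 13]}
```

Nested or commuted additions such as `(i + 2) + 10` and `13 + i` still land in the same group, which is the property a generic GEP-level pass would exploit.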

Let me know if there are any suggestions or concerns. @turbo0628 @ailzhang

Current Implementation

[chi ir]
ptr0 = PtrOffsetStmt(array, AddStmt(LoopIndexStmt, 10))
ptr1 = PtrOffsetStmt(array, AddStmt(LoopIndexStmt, 11))
ptr2 = PtrOffsetStmt(array, AddStmt(LoopIndexStmt, 12))
ptr3 = PtrOffsetStmt(array, AddStmt(LoopIndexStmt, 13))

[llvm]
ptr0 = GEP(array, i + 10)
ptr1 = GEP(array, i + 11)
ptr2 = GEP(array, i + 12)
ptr3 = GEP(array, i + 13)

[ptx]
ptr0 = array + (i + 10) * 4
ptr1 = array + (i + 11) * 4
ptr2 = array + (i + 12) * 4
ptr3 = array + (i + 13) * 4

Optimized Implementation

[chi ir]
base = AddStmt(array, MulStmt(LoopIndexStmt, sizeof(dtype)))
ptr0 = PtrOffsetStmt(base, 10)
ptr1 = PtrOffsetStmt(base, 11)
ptr2 = PtrOffsetStmt(base, 12)
ptr3 = PtrOffsetStmt(base, 13)

[llvm]
// Only works on an array
base = GEP(array, i)
ptr0 = GEP(base, 10)
ptr1 = GEP(base, 11)
ptr2 = GEP(base, 12)
ptr3 = GEP(base, 13)

[ptx]
base = array + i * 4
ptr0 = base + 10 * 4
ptr1 = base + 11 * 4
ptr2 = base + 12 * 4
ptr3 = base + 13 * 4
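Option 1 from the plan, a CHI IR pass, could look roughly like the sketch below. The dataclasses are hypothetical Python stand-ins for CHI IR statements (Taichi's real PtrOffsetStmt and related statements are C++ classes with different fields). The sketch matches `PtrOffset(array, Add(loop_index, const))`, emits one hoisted base per `(array, loop_index)` pair, and rewrites each access to `PtrOffset(base, const)`.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

# Hypothetical, simplified stand-ins for CHI IR statements (not Taichi's C++ classes).


@dataclass(frozen=True)
class LoopIndex:
    name: str


@dataclass(frozen=True)
class Add:
    lhs: object
    rhs: object


@dataclass(frozen=True)
class BasePtr:
    """Hoisted base: array + index * sizeof(dtype), computed once."""
    array: str
    index: LoopIndex


@dataclass(frozen=True)
class PtrOffset:
    base: object    # an array name or a hoisted BasePtr
    offset: object  # element offset: an int or an Add expression


Stmt = Union[BasePtr, PtrOffset]


def hoist_common_base(stmts: List[PtrOffset]) -> List[Stmt]:
    """Rewrite PtrOffset(arr, Add(idx, c)) into a shared BasePtr(arr, idx) plus PtrOffset(base, c)."""
    bases: Dict[Tuple[object, LoopIndex], BasePtr] = {}
    out: List[Stmt] = []
    for s in stmts:
        off = s.offset
        if isinstance(off, Add) and isinstance(off.lhs, LoopIndex) and isinstance(off.rhs, int):
            key = (s.base, off.lhs)
            if key not in bases:
                bases[key] = BasePtr(s.base, off.lhs)
                out.append(bases[key])  # emit the base once, before its first use
            out.append(PtrOffset(bases[key], off.rhs))
        else:
            out.append(s)  # any other pattern is left untouched
    return out


i = LoopIndex("i")
before = [PtrOffset("array", Add(i, c)) for c in (10, 11, 12, 13)]
for stmt in hoist_common_base(before):
    print(stmt)
```

A real pass would also need to check that the hoisted base dominates every rewritten use and that the array and dtype really match; this sketch only shows the pattern rewrite itself.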

Posted by jim19930609, Jul 20 '22 03:07

+1 that we should try to delegate this optimization to LLVM if possible. Thanks for the detailed proposal! nit: the LLVM code in the optimized implementation should be:

[llvm]
// Only works on an array
base = array + i * 4
ptr0 = GEP(base, 10)
ptr1 = GEP(base, 11)
ptr2 = GEP(base, 12)
ptr3 = GEP(base, 13)
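For reference, LLVM's `getelementptr` scales each index by the size of the indexed element type, so on an array of 4-byte f32 values `GEP(array, i)` and the byte expression `array + i * 4` denote the same address. The simplified single-index model below (the `gep` helper and the `ptr_*` functions are hypothetical and ignore LLVM types entirely) checks that the three formulations of the pointers discussed in this thread agree.

```python
ELEM_SIZE = 4  # bytes per f32 element (assumption for this example)


def gep(ptr: int, index: int, elem_size: int = ELEM_SIZE) -> int:
    """Simplified single-index GEP: advance ptr by `index` elements of `elem_size` bytes."""
    return ptr + index * elem_size


def ptr_current(array: int, i: int, k: int) -> int:
    return gep(array, i + k)            # ptrK = GEP(array, i + k)


def ptr_hoisted_gep(array: int, i: int, k: int) -> int:
    base = gep(array, i)                # base = GEP(array, i)
    return gep(base, k)                 # ptrK = GEP(base, k)


def ptr_hoisted_bytes(array: int, i: int, k: int) -> int:
    base = array + i * ELEM_SIZE        # base = array + i * 4 (byte arithmetic)
    return gep(base, k)                 # ptrK = GEP(base, k), index still scaled by 4


for k in (10, 11, 12, 13):
    assert ptr_current(0x1000, 7, k) == ptr_hoisted_gep(0x1000, 7, k) == ptr_hoisted_bytes(0x1000, 7, k)
```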

Posted by ailzhang, Jul 20 '22 03:07

Thanks!

Posted by jim19930609, Jul 20 '22 04:07

It's great to try the LLVM pass as an ad hoc implementation.

Once the performance gain is verified, I think we should port the pass to the CHI IR level, since it would also benefit SPIR-V targets.

Posted by turbo0628, Jul 20 '22 06:07