
[BUG] Excessive Copy Loops

Open · matth2k opened this issue 1 year ago · 5 comments

Describe the bug

Excessive copy loops are created by the data type conversions applied to tensors expressed in the linalg dialect.

To Reproduce

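import allo
from allo.ir.types import uint32

N = 20  # matches the memref<20xi32> shapes in the output below

def test_vadd():
    def kernel(A: uint32[N], B: uint32[N]) -> uint32[N]:
        # Element-wise add of two unsigned 32-bit vectors
        return A + B

    s = allo.customize(kernel)
    print(s.module)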

Buggy output

#map = affine_map<(d0) -> (d0)>
module {
  func.func @kernel(%arg0: memref<20xi32>, %arg1: memref<20xi32>) -> memref<20xi32> attributes {itypes = "uu", otypes = "u"} {
    %alloc = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : memref<20xi32>) outs(%alloc : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_0 = memref.alloc() {unsigned} : memref<20xi33>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg1 : memref<20xi32>) outs(%alloc_0 : memref<20xi33>) {
    ^bb0(%in: i32, %out: i33):
      %0 = arith.extui %in : i32 to i33
      linalg.yield %0 : i33
    }
    %alloc_1 = memref.alloc() : memref<20xi33>
    linalg.add {op_name = "add_0"} ins(%alloc, %alloc_0 : memref<20xi33>, memref<20xi33>) outs(%alloc_1 : memref<20xi33>)
    %alloc_2 = memref.alloc() {unsigned} : memref<20xi32>
    linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%alloc_1 : memref<20xi33>) outs(%alloc_2 : memref<20xi32>) {
    ^bb0(%in: i33, %out: i32):
      %0 = arith.trunci %in : i33 to i32
      linalg.yield %0 : i32
    }
    return %alloc_2 : memref<20xi32>
  }
}

In short, when this is lowered to the affine dialect, it shows up as excessive copying at the beginning of the program, and our AMC flow is very sensitive to this.

What should really happen is that the compiler notices that the type the addition's result is bound to is the same as the input type, and simply emits add (i32, i32) -> i32 with normal wraparound.
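
Why a single i32 add is enough: truncating the i33 sum back to 32 bits discards only the carry-out bit, which is exactly what a wraparound add drops, so extend, add, truncate equals addition modulo 2^32. The small Python model below checks this on edge cases and random inputs; the function names are illustrative and not part of Allo.

```python
import random

MASK32 = (1 << 32) - 1
MASK33 = (1 << 33) - 1


def extend_add_truncate(a: int, b: int) -> int:
    """Models the buggy IR: extui to i33, add in i33, trunci back to i32."""
    a33 = a & MASK32            # arith.extui: zero-extension keeps the value
    b33 = b & MASK32
    s33 = (a33 + b33) & MASK33  # i33 add: the sum of two u32 values always fits
    return s33 & MASK32         # arith.trunci: keep only the low 32 bits


def add_wrap_32(a: int, b: int) -> int:
    """Models a single i32 add with ordinary wraparound."""
    return (a + b) & MASK32


if __name__ == "__main__":
    rng = random.Random(0)
    cases = [(0, 0), (MASK32, 1), (MASK32, MASK32)]
    cases += [(rng.getrandbits(32), rng.getrandbits(32)) for _ in range(1000)]
    assert all(extend_add_truncate(a, b) == add_wrap_32(a, b) for a, b in cases)
    print("extend/add/truncate == wraparound add for", len(cases), "cases")
```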

matth2k, Dec 12 '23

Thanks for bringing this issue up. Allo has a strong type system, which is why it must guarantee that intermediate results do not overflow. I think we can either (1) check the input and output data types and bypass the type-extension rule if they align, or (2) fuse those linalg operations into one.
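
Option (1) amounts to a special case in the type rule for addition. The sketch below is hypothetical: `UInt` and `add_result_type` are illustrative names, not Allo's actual type-inference code, and it assumes the rule can see the type the result is bound to (here the kernel's `uint32[N]` return type).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class UInt:
    bits: int  # width of an unsigned integer type


def add_result_type(lhs: UInt, rhs: UInt, target: Optional[UInt]) -> UInt:
    """Hypothetical type rule for `lhs + rhs` on unsigned integers."""
    if target is not None and lhs == rhs == target:
        # Types align: compute directly in the target width with wraparound,
        # so no extension or truncation is needed.
        return target
    # Default rule: widen by one bit so the sum cannot overflow.
    return UInt(max(lhs.bits, rhs.bits) + 1)
```

With matching types the rule yields a single 32-bit add; in every other case it keeps the widening that protects against overflow.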

As a workaround, you can explicitly iterate over the array elements with for loops, so that no linalg operations are built.

chhzh123, Dec 13 '23

Your explanation makes sense, @chhzh123, but I wonder what optimization HLS performs to avoid this issue. Does it simply fully unroll the copy loops so that the extension and truncation come at no cost? If so, maybe we can do a similar optimization in AMC and avoid this issue altogether.

andrewb1999, Dec 14 '23

I think Vivado/Vitis HLS only unrolls loops with small loop bounds automatically; otherwise, we need to add an explicit unroll pragma. However, unrolling may incur excessive resource usage, so I still think the best way is to fuse the loops into one.
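
To make the fusion option concrete, here is a plain-Python model (not Allo or MLIR code; the function names are illustrative) comparing the loop structure of the buggy output with a fused version. Both compute the same result, but the fused form makes a single pass and needs no temporary arrays, which is what removes the copy loops.

```python
from typing import List

MASK32 = (1 << 32) - 1


def unfused(a: List[int], b: List[int]) -> List[int]:
    # Same shape as the buggy IR: four loops and three temporary arrays.
    a33 = [x for x in a]                     # copy loop: extui A to i33
    b33 = [x for x in b]                     # copy loop: extui B to i33
    s33 = [x + y for x, y in zip(a33, b33)]  # linalg.add in i33
    return [x & MASK32 for x in s33]         # copy loop: trunci to i32


def fused(a: List[int], b: List[int]) -> List[int]:
    # One loop, no temporaries: extend, add and truncate happen per element.
    out = [0] * len(a)
    for i in range(len(a)):
        out[i] = (a[i] + b[i]) & MASK32
    return out
```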

chhzh123, Dec 15 '23

I think the main hiccup is that we are lowering to linalg, which is less expressive than imperative programs. As a result, we have to extend both input vectors to int33 first, then add them, and finally truncate the result back to int32. To clean this up, we really need an extra pass that removes the unnecessary extensions and truncation. Another option is to insert a primitive that fuses the loops so that a lower-level optimization pass can finish the job, though this is not a good solution.
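
The cleanup pass described above is essentially a peephole rewrite: `trunci(addi(extui a, extui b))` back to the original width can be replaced by a plain `addi` in the narrow type. The sketch below applies that rule to a tiny hypothetical expression IR in Python (the `Var`/`ExtUI`/`Add`/`TruncI` classes are illustrative, not MLIR's API). A real version would be an MLIR rewrite pattern, and since the buggy output spreads these ops across separate linalg ops connected by memrefs, it would likely need the producers fused or inlined first.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:
    name: str
    bits: int


@dataclass(frozen=True)
class ExtUI:
    src: "Expr"
    bits: int


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"
    bits: int


@dataclass(frozen=True)
class TruncI:
    src: "Expr"
    bits: int


Expr = Union[Var, ExtUI, Add, TruncI]


def simplify(e: Expr) -> Expr:
    """Rewrite trunci(add(extui(a), extui(b))) to add(a, b) when widths agree."""
    if isinstance(e, TruncI):
        inner = simplify(e.src)
        if (isinstance(inner, Add)
                and isinstance(inner.lhs, ExtUI) and isinstance(inner.rhs, ExtUI)
                and inner.lhs.src.bits == inner.rhs.src.bits == e.bits):
            # Truncating back to the original width discards only the carry
            # bit, so a wraparound add in the narrow type is equivalent.
            return Add(inner.lhs.src, inner.rhs.src, e.bits)
        return TruncI(inner, e.bits)
    if isinstance(e, Add):
        return Add(simplify(e.lhs), simplify(e.rhs), e.bits)
    if isinstance(e, ExtUI):
        return ExtUI(simplify(e.src), e.bits)
    return e
```

Operands whose original widths do not match the truncation target are left untouched, so the overflow-safe form is kept wherever the rewrite would change results.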

zhangzhiru, Dec 15 '23

I have implemented the fix within our AMC backend, but I will eventually come up with a more universal solution and submit it to Allo as a separate PR.

matth2k, Jan 14 '24