circt
circt copied to clipboard
Transform Flatten Memref Load
Caught a small issue with flattening the load.
Say we have an example:
module {
func.func @main(%arg0: memref<3x4xi32>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<4x2xi32>
scf.for %arg1 = %c0 to %c3 step %c1 {
scf.for %arg2 = %c0 to %c2 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
%0 = memref.load %arg0[%arg1, %arg3] : memref<3x4xi32>
%1 = memref.load %alloc[%arg3, %arg2] : memref<4x2xi32>
%2 = arith.muli %1, %2 : i32
}
}
}
return
}
}
Previously, it gives:
module {
func.func @main(%arg0: memref<12xi32>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<8xi32>
scf.for %arg1 = %c0 to %c3 step %c1 {
scf.for %arg2 = %c0 to %c2 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
%c3_1 = arith.constant 3 : index
%1 = arith.muli %arg3, %c3_1 : index
%2 = arith.addi %arg1, %1 : index
%3 = memref.load %arg0[%2] : memref<12xi32>
%c2_2 = arith.constant 2 : index
%4 = arith.shli %arg2, %c2_2 : index
%5 = arith.addi %arg3, %4 : index
%6 = memref.load %alloc[%5] : memref<8xi32>
%7 = arith.muli %3, %6 : i32
}
}
}
return
}
}
It is wrong because it's doing:
%1 = arith.muli %arg3, %c3_1 : index
%2 = arith.addi %arg1, %1 : index
And the order of accessing the memory is:
%arg1 = 0, %arg3 = 0 -> access 0 * 3 + 0 = 0;
%arg1 = 0, %arg3 = 1 -> access 1 * 3 + 0 = 0; (oops, wrong because we should access address 1)
What it should be, which is also the result after fixing it, is:
module {
func.func @main(%arg0: memref<12xi32>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<8xi32>
scf.for %arg1 = %c0 to %c3 step %c1 {
scf.for %arg2 = %c0 to %c2 step %c1 {
scf.for %arg3 = %c0 to %c4 step %c1 {
%c2_1 = arith.constant 2 : index
%1 = arith.shli %arg1, %c2_1 : index
%2 = arith.addi %1, %arg3 : index
%3 = memref.load %arg0[%2] : memref<12xi32>
%c1_2 = arith.constant 1 : index
%4 = arith.shli %arg3, %c1_2 : index
%5 = arith.addi %4, %arg2 : index
%6 = memref.load %alloc[%5] : memref<8xi32>
%7 = arith.muli %3, %6 : i32
}
}
}
return
}
}