[AffineToPipeline] Unroll loops
Summary: MWE of loop unrolling + naive store-load forwarding in order to lower nested loops to the pipeline dialect.
Details:
This is a really dirty sketch of the hoops I had to jump through in order to unroll this:
func.func @forward(%arg0: memref<1x2xf32>) -> memref<1x2xf32> {
  %3 = memref.alloca() : memref<2x2xf32>
  %4 = memref.alloca() : memref<1x2xf32>
  affine.for %arg1 = 0 to 1 {
    affine.for %arg2 = 0 to 2 {
      affine.for %arg3 = 0 to 2 {
        %5 = affine.load %arg0[%arg1, %arg3] : memref<1x2xf32>
        %6 = affine.load %3[%arg3, %arg2] : memref<2x2xf32>
        %7 = affine.load %4[%arg1, %arg2] : memref<1x2xf32>
        %8 = arith.mulf %5, %6 : f32
        %9 = arith.addf %7, %8 : f32
        affine.store %9, %4[%arg1, %arg2] : memref<1x2xf32>
      }
    }
  }
  return %4 : memref<1x2xf32>
}
and produce this:
module {
  func.func @matvecmul(%arg0: memref<1x2xi32>) -> memref<1x2xi32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = memref.alloca() : memref<2x2xi32>
    %1 = memref.alloca() : memref<1x2xi32>
    %c0_0 = arith.constant 0 : index
    %c1_1 = arith.constant 1 : index
    %c1_2 = arith.constant 1 : index
    pipeline.while II = 1 trip_count = 1 iter_args(%arg1 = %c0_0) : (index) -> () {
      %2 = arith.cmpi ult, %arg1, %c1_1 : index
      pipeline.register %2 : i1
    } do {
      %2:9 = pipeline.while.stage start = 0 {
        %5 = memref.load %arg0[%arg1, %c0] : memref<1x2xi32>
        %6 = memref.load %0[%c0, %c0] : memref<2x2xi32>
        %7 = memref.load %arg0[%arg1, %c1] : memref<1x2xi32>
        %8 = memref.load %0[%c1, %c0] : memref<2x2xi32>
        %9 = memref.load %arg0[%arg1, %c0] : memref<1x2xi32>
        %10 = memref.load %0[%c0, %c1] : memref<2x2xi32>
        %11 = memref.load %arg0[%arg1, %c1] : memref<1x2xi32>
        %12 = memref.load %0[%c1, %c1] : memref<2x2xi32>
        %13 = arith.addi %arg1, %c1_2 : index
        pipeline.register %5, %6, %7, %8, %9, %10, %11, %12, %13 : i32, i32, i32, i32, i32, i32, i32, i32, index
      } : i32, i32, i32, i32, i32, i32, i32, i32, index
      %3:4 = pipeline.while.stage start = 1 {
        %5 = arith.muli %2#0, %2#1 : i32
        %6 = arith.muli %2#2, %2#3 : i32
        %7 = arith.muli %2#4, %2#5 : i32
        %8 = arith.muli %2#6, %2#7 : i32
        pipeline.register %5, %6, %7, %8 : i32, i32, i32, i32
      } : i32, i32, i32, i32
      %4:2 = pipeline.while.stage start = 3 {
        %5 = memref.load %1[%arg1, %c0] : memref<1x2xi32>
        %6 = memref.load %1[%arg1, %c1] : memref<1x2xi32>
        pipeline.register %5, %6 : i32, i32
      } : i32, i32
      pipeline.while.stage start = 4 {
        %5 = arith.addi %4#0, %3#0 : i32
        %6 = arith.addi %5, %3#1 : i32
        memref.store %6, %1[%arg1, %c0] : memref<1x2xi32>
        %7 = arith.addi %4#1, %3#2 : i32
        %8 = arith.addi %7, %3#3 : i32
        memref.store %8, %1[%arg1, %c1] : memref<1x2xi32>
        pipeline.register
      }
      pipeline.terminator iter_args(%2#8), results() : (index) -> ()
    }
    return %1 : memref<1x2xi32>
  }
}
The only really crucial part is forwardStoreToLoad, which is a braindead-naive implementation of store-to-load forwarding (reverse-iterate over the loads and forward the first matching store). Note that this unrolls down to a loop nest that is one loop deep (under the theory that a loop corresponds to a pipeline). It won't work for many (most?) IR patterns, but it will work for MACs (particularly those corresponding to NNs).
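For the record, here is a minimal sketch of that naive forwarding idea; this is not the PR's actual code. It assumes the accesses are still affine.load/affine.store within a single block, the function name is made up, and the namespaces follow the MLIR version where the affine ops live directly under mlir:: (adjust for mlir::affine:: on newer versions).

// Hypothetical sketch of naive store-to-load forwarding: for each load,
// scan backwards in its block and forward the first store that writes the
// exact same memref element (same memref, same access map and operands).
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static void forwardStoreToLoadNaive(mlir::Block &block) {
  // Collect the loads first so we can erase them while forwarding.
  llvm::SmallVector<mlir::AffineLoadOp> loads;
  for (mlir::Operation &op : block)
    if (auto load = llvm::dyn_cast<mlir::AffineLoadOp>(&op))
      loads.push_back(load);

  for (mlir::AffineLoadOp load : loads) {
    for (mlir::Operation *prev = load->getPrevNode(); prev;
         prev = prev->getPrevNode()) {
      auto store = llvm::dyn_cast<mlir::AffineStoreOp>(prev);
      if (!store || store.getMemRef() != load.getMemRef())
        continue;
      // First store to the same memref: forward it only if it provably
      // writes the same element; either way, stop scanning (conservative).
      if (store.getAffineMap() == load.getAffineMap() &&
          llvm::equal(store.getMapOperands(), load.getMapOperands())) {
        load.getResult().replaceAllUsesWith(store.getValueToStore());
        load.erase();
      }
      break;
    }
  }
}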
I'm sure I did a lot of things wrong here because I'm a n00b (in particular I couldn't figure out how to run a nested CSE pass to clean up the constants). Happy to iterate.
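On the nested CSE question, a hedged sketch of one mechanism is MLIR's dynamic pass pipeline hook (Pass::runPipeline). The pass class below is a made-up stand-in anchored on func.func, not this PR's pass, and whether this is the right place to clean up the duplicated constants is exactly the open question.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

namespace {
// Hypothetical pass used only to illustrate running CSE nested from inside
// a func.func-anchored pass via a dynamic pass pipeline.
struct ExampleLoweringPass
    : public mlir::PassWrapper<ExampleLoweringPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    // ... the actual lowering would go here and may leave duplicated
    // arith.constant ops behind ...

    // Run CSE over this function; the pipeline's anchor ("func.func")
    // must match the op it is run on.
    mlir::OpPassManager cleanupPM("func.func");
    cleanupPM.addPass(mlir::createCSEPass());
    if (mlir::failed(runPipeline(cleanupPM, getOperation())))
      return signalPassFailure();
  }
};
} // namespace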
Thanks @Groverkss for the tips on speeding up forwardStoreToLoad.
Addendum
The second commit increases the hackiness dramatically but enables lowering an entire small CNN model; roughly
class ConvPlusReLU(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, bias=bias)
        self.conv2 = torch.nn.Conv2d(out_channels, in_channels, 3, bias=bias)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
The key hack is the stubbing out of
// auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();
The reason for this is that the analysis negates the win of the naive forwardStoreToLoad implementation: it performs exactly the same expensive memory-dependence checks that a non-naive forwardStoreToLoad would have to perform.
Incidentally, from what I can tell, that dependenceAnalysis isn't actually being used for anything.
Of the other two changes, one is a hack and the other is a question of design choice; the reason for
for (auto *op : group) {
  for (auto *user : op->getUsers()) {
    if (*problem.getStartTime(user) > startTime || isLoopTerminator(user)) {
      opsWithReturns.insert(op);
      stageTypes.append(op->getResultTypes().begin(),
                        op->getResultTypes().end());
+     break;
    }
  }
}
is that the last layer of the small_cnn has a load that has two users:
affine.for %arg2 = 0 to 2 {
  ...
  %9 = affine.load %7[%c0, %arg2, %c0, %c0] : memref<1x2x1x1xi32>
  %10 = arith.cmpi ugt, %9, %cst : i32
  %11 = arith.select %10, %9, %cst : i32
  affine.store %11, %8[%arg1, %arg2, %arg3, %arg4] : memref<1x2x1x1xi32>
  ...
}
This (without the break;) shows up as a pipeline.register that registers two i32 values (plus the loop index) but a pipeline.while.stage with four i32 result types (plus the index), because the load's result types are appended once per user rather than once per op:
%24:5 = "pipeline.while.stage"() ({
%25 = "memref.load"(%16, %0, %0, %0, %0) : (memref<1x2x1x1xi32>, index, index, index, index) -> i32
%26 = "memref.load"(%16, %0, %1, %0, %0) : (memref<1x2x1x1xi32>, index, index, index, index) -> i32
%27 = "arith.addi"(%arg1, %23) : (index, index) -> index
"pipeline.register"(%25, %26, %27) : (i32, i32, index) -> ()
}) {start = 0 : si64} : () -> (i32, i32, i32, i32, index)
I do not know whether the right answer here is to "register twice" (once for each user) or have both users read from the same register.
Finally, the reason for the
// llvm::sort(group,
// [&](Operation *a, Operation *b) { return dom.dominates(a, b); });
is that without it I get null operands in various places:
%197 = "memref.load"(%arg0, %arg1, %1, <<NULL VALUE>>, %0) :
(memref<1x2x5x5xi32>, index, index, <<NULL TYPE>>, index) -> i32
I'm still debugging this one.
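Not a diagnosis, just a note on that commented-out sort: DominanceInfo::dominates(a, a) returns true, so it is not a strict weak ordering, and handing it to llvm::sort is undefined behavior that could plausibly scramble the group and produce the null operands above. If everything in group lives in one block, sorting by block order is a well-defined alternative (sketch, not verified against this branch):

// Hedged alternative (assumes all ops in `group` sit in the same block):
// Operation::isBeforeInBlock gives a strict ordering suitable for llvm::sort.
llvm::sort(group, [](mlir::Operation *a, mlir::Operation *b) {
  return a->isBeforeInBlock(b);
});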
Thanks for pushing on this, this is awesome! From a high level, this is something I've wanted to do. I'll take a look at the implementation in a bit.
Regarding the hacks, and the stuff you couldn't figure out, I will check out this branch and try to jog my memory. I'm sure we can clean them up. The end goal is awesome, but let's not merge anything until we resolve the hacky bits.
Status?
@darthscsi
I've kind of taken a detour from CIRCT but if you want/need this I can finish it up ~end of June.
I think @andrewb1999 and @matth2k have some ideas about this. I'm not sure if they have any efforts in this direction yet, but maybe this PR can be dusted off for that work (or closed if there is an alternative implementation). We briefly discussed on Discord adding some kind of pass to handle these transformations.