
[AffineToPipeline] Unroll loops

Open · makslevental opened this issue 1 year ago • 1 comment

Summary: MWE of loop unrolling + naive store-load forwarding in order to lower nested loops to the pipeline dialect.

Details:

This is a really dirty sketch of the hoops I had to jump through in order to unroll this

func.func @forward(%arg0: memref<1x2xf32>) -> memref<1x2xf32> {
  %3 = memref.alloca() : memref<2x2xf32>
  %4 = memref.alloca() : memref<1x2xf32>
  affine.for %arg1 = 0 to 1 {
    affine.for %arg2 = 0 to 2 {
      affine.for %arg3 = 0 to 2 {
        %5 = affine.load %arg0[%arg1, %arg3] : memref<1x2xf32>
        %6 = affine.load %3[%arg3, %arg2] : memref<2x2xf32>
        %7 = affine.load %4[%arg1, %arg2] : memref<1x2xf32>
        %8 = arith.mulf %5, %6 : f32
        %9 = arith.addf %7, %8 : f32
        affine.store %9, %4[%arg1, %arg2] : memref<1x2xf32>
      }
    }
  }
  return %4 : memref<1x2xf32>
}

and produce this

module {
  func.func @matvecmul(%arg0: memref<1x2xi32>) -> memref<1x2xi32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = memref.alloca() : memref<2x2xi32>
    %1 = memref.alloca() : memref<1x2xi32>
    %c0_0 = arith.constant 0 : index
    %c1_1 = arith.constant 1 : index
    %c1_2 = arith.constant 1 : index
    pipeline.while II =  1 trip_count =  1 iter_args(%arg1 = %c0_0) : (index) -> () {
      %2 = arith.cmpi ult, %arg1, %c1_1 : index
      pipeline.register %2 : i1
    } do {
      %2:9 = pipeline.while.stage start = 0 {
        %5 = memref.load %arg0[%arg1, %c0] : memref<1x2xi32>
        %6 = memref.load %0[%c0, %c0] : memref<2x2xi32>
        %7 = memref.load %arg0[%arg1, %c1] : memref<1x2xi32>
        %8 = memref.load %0[%c1, %c0] : memref<2x2xi32>
        %9 = memref.load %arg0[%arg1, %c0] : memref<1x2xi32>
        %10 = memref.load %0[%c0, %c1] : memref<2x2xi32>
        %11 = memref.load %arg0[%arg1, %c1] : memref<1x2xi32>
        %12 = memref.load %0[%c1, %c1] : memref<2x2xi32>
        %13 = arith.addi %arg1, %c1_2 : index
        pipeline.register %5, %6, %7, %8, %9, %10, %11, %12, %13 : i32, i32, i32, i32, i32, i32, i32, i32, index
      } : i32, i32, i32, i32, i32, i32, i32, i32, index
      %3:4 = pipeline.while.stage start = 1 {
        %5 = arith.muli %2#0, %2#1 : i32
        %6 = arith.muli %2#2, %2#3 : i32
        %7 = arith.muli %2#4, %2#5 : i32
        %8 = arith.muli %2#6, %2#7 : i32
        pipeline.register %5, %6, %7, %8 : i32, i32, i32, i32
      } : i32, i32, i32, i32
      %4:2 = pipeline.while.stage start = 3 {
        %5 = memref.load %1[%arg1, %c0] : memref<1x2xi32>
        %6 = memref.load %1[%arg1, %c1] : memref<1x2xi32>
        pipeline.register %5, %6 : i32, i32
      } : i32, i32
      pipeline.while.stage start = 4 {
        %5 = arith.addi %4#0, %3#0 : i32
        %6 = arith.addi %5, %3#1 : i32
        memref.store %6, %1[%arg1, %c0] : memref<1x2xi32>
        %7 = arith.addi %4#1, %3#2 : i32
        %8 = arith.addi %7, %3#3 : i32
        memref.store %8, %1[%arg1, %c1] : memref<1x2xi32>
        pipeline.register
      }
      pipeline.terminator iter_args(%2#8), results() : (index) -> ()
    }
    return %1 : memref<1x2xi32>
  }
}

The only really crucial part is forwardStoreToLoad, which is a braindead-naive implementation of store-load forwarding (reverse-iterate over the loads and take the first matching store). Note this unrolls down to a loop nest one level deep (under the theory that a loop corresponds to a pipeline). It won't work for many (most?) IR patterns, but it will work for MACs (particularly those corresponding to NNs).
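For reference, here's a minimal sketch of that naive forwarding idea (not the code in the branch, just an illustration): for each affine.load, scan backwards through the block and forward the value of the first affine.store that writes the same memref with the same access map and indices.

// Hedged sketch of the naive store-load forwarding described above; assumes
// the unrolled body sits in a single block and that nothing between a
// matching store and the load can write the same memref.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

static void naiveForwardStoreToLoad(Block &block) {
  // Collect the loads first so we can erase them while iterating.
  SmallVector<AffineLoadOp> loads;
  for (auto load : block.getOps<AffineLoadOp>())
    loads.push_back(load);

  for (AffineLoadOp load : loads) {
    // Walk backwards from the load to the first matching store.
    for (Operation *prev = load->getPrevNode(); prev;
         prev = prev->getPrevNode()) {
      auto store = dyn_cast<AffineStoreOp>(prev);
      if (!store)
        continue;
      // "Matching" = same memref, same affine map, same map operands.
      if (store.getMemRef() == load.getMemRef() &&
          store.getAffineMap() == load.getAffineMap() &&
          llvm::equal(store.getMapOperands(), load.getMapOperands())) {
        load.getResult().replaceAllUsesWith(store.getValueToStore());
        load.erase();
        break;
      }
    }
  }
}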

I'm sure I did a lot of things wrong here because I'm a n00b (in particular I couldn't figure out how to run a nested CSE pass to clean up the constants). Happy to iterate.
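(For what it's worth, one way to get that cleanup without running it from inside the pass is to schedule CSE/canonicalize nested on each func.func in the surrounding pipeline; a rough sketch, not necessarily what this PR should do:)

// Sketch: schedule CSE + canonicalize nested on func.func after unrolling,
// so the duplicated arith.constant ops get folded away.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

void addPostUnrollCleanup(mlir::PassManager &pm) {
  mlir::OpPassManager &funcPM = pm.nest<mlir::func::FuncOp>();
  funcPM.addPass(mlir::createCSEPass());
  funcPM.addPass(mlir::createCanonicalizerPass());
}

From circt-opt/mlir-opt the equivalent should be something like --pass-pipeline='func.func(cse,canonicalize)'.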

Thanks @Groverkss for the tips on speeding up forwardStoreToLoad.

Addendum

The second commit increases the hackiness dramatically but enables lowering an entire small CNN model; roughly

import torch
from torch import nn

class ConvPlusReLU(nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, bias=bias)
        self.conv2 = torch.nn.Conv2d(out_channels, in_channels, 3, bias=bias)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x

Final IR.

The key hack is the stubbing out of

// auto dependenceAnalysis = getAnalysis<MemoryDependenceAnalysis>();

The reason for this is that the analysis exactly negates the benefit of the naive forwardStoreToLoad: it performs exactly the same expensive memory-dependence checks that the non-naive forwardStoreToLoad performs. Incidentally, from what I can tell, that dependenceAnalysis isn't actually being used for anything.

Of the other two changes, one is a hack and the other is a question of design choice; the reason for

    for (auto *op : group) {
      for (auto *user : op->getUsers()) {
        if (*problem.getStartTime(user) > startTime || isLoopTerminator(user)) {
          opsWithReturns.insert(op);
          stageTypes.append(op->getResultTypes().begin(),
                            op->getResultTypes().end());
+          break;
        }
      }
    }

is that the last layer of the small_cnn has a load that has two users:

      affine.for %arg2 = 0 to 2 {
        ...
            %9 = affine.load %7[%c0, %arg2, %c0, %c0] : memref<1x2x1x1xi32>
            %10 = arith.cmpi ugt, %9, %cst : i32
            %11 = arith.select %10, %9, %cst : i32
            affine.store %11, %8[%arg1, %arg2, %arg3, %arg4] : memref<1x2x1x1xi32>
        ...
      }

Without the break;, this shows up as a pipeline.register that registers the two loaded values but a pipeline.while.stage with four i32 result types (each load's result type gets appended once per user):

    %24:5 = "pipeline.while.stage"() ({
      %25 = "memref.load"(%16, %0, %0, %0, %0) : (memref<1x2x1x1xi32>, index, index, index, index) -> i32
      %26 = "memref.load"(%16, %0, %1, %0, %0) : (memref<1x2x1x1xi32>, index, index, index, index) -> i32
      %27 = "arith.addi"(%arg1, %23) : (index, index) -> index
      "pipeline.register"(%25, %26, %27) : (i32, i32, index) -> ()
    }) {start = 0 : si64} : () -> (i32, i32, i32, i32, index)

I do not know whether the right answer here is to "register twice" (once for each user) or have both users read from the same register.

Finally, the reason for the

//    llvm::sort(group,
//               [&](Operation *a, Operation *b) { return dom.dominates(a, b); });

is that without it I get null operands in various places:

%197 = "memref.load"(%arg0, %arg1, %1, <<NULL VALUE>>, %0) :
 (memref<1x2x5x5xi32>, index, index, <<NULL TYPE>>, index) -> i32

I'm still debugging this one.

makslevental avatar Sep 21 '22 22:09 makslevental

Thanks for pushing on this, this is awesome! From a high level, this is something I've wanted to do. I'll take a look at the implementation in a bit.

Regarding the hacks, and the stuff you couldn't figure out, I will check out this branch and try to jog my memory. I'm sure we can clean them up. The end goal is awesome, but let's not merge something in till we resolve the hacky bits.

mikeurbach avatar Sep 22 '22 19:09 mikeurbach

Status?

darthscsi avatar Apr 28 '23 18:04 darthscsi

@darthscsi

Status?

I've kind of taken a detour from CIRCT but if you want/need this I can finish it up ~end of June.

makslevental avatar Apr 28 '23 20:04 makslevental

I think @andrewb1999 and @matth2k have some ideas about this. Not sure if they had any efforts in this direction yet, but maybe this PR can be dusted off for that work (or closed if there is an alternative implementation). We briefly discussed on Discord adding some kind of pass to handle these transformations.

mikeurbach avatar May 01 '23 16:05 mikeurbach