circt

Pipeline For Lowering PyTorch Tensor Add To FPGA

Opened by jiahanxie353

Overview

This patch is the first in a series of PRs that aim to support a complete pipeline from a PyTorch add operation to an FPGA.

The lowering process consists of the following steps:

  1. Lowering from a PyTorch source program written in Python to MLIR;
  2. Lowering from the emitted MLIR program to CIRCT;
  3. Exporting SystemVerilog programs and attempting to simulate them on FPGAs.

Breakdown

Step 1: From PyTorch to MLIR

The source PyTorch program is a simple torch.add:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Element-wise addition, equivalent to torch.add(x, y)
        return x + y

This step is done with Allo. Using the following example:

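import allo
import torch

model = Model()
model.eval()  # put the model in inference mode
# Two random inputs of shape (N, C, H, W) = (1, 3, 10, 10)
example_inputs = [torch.rand(1, 3, 10, 10), torch.rand(1, 3, 10, 10)]
mlir_mod = allo.frontend.from_pytorch(
    model, example_inputs=example_inputs, verbose=False, enable_tensor=True
)
print(mlir_mod)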

we can emit the following MLIR:

module {
  func.func @forward(%arg0: tensor<1x3x10x10xf32>, %arg1: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> attributes {itypes = "__", otypes = "_"} {
    %0 = tensor.empty() : tensor<1x3x10x10xf32>
    %1 = linalg.add {name = "add", op_name = "add_0"} ins(%arg0, %arg1 : tensor<1x3x10x10xf32>, tensor<1x3x10x10xf32>) outs(%0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32>
    return %1 : tensor<1x3x10x10xf32>
  }
}
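linalg.add writes the element-wise sum of its two ins operands into its outs operand. Because every element is overwritten, the uninitialized contents of the tensor.empty() destination never matter. As a minimal sketch (not Allo or MLIR code), the same semantics for the 1x3x10x10 shape can be written in plain Python, assuming nested lists stand in for tensors; zeros and linalg_add_ref are hypothetical helper names:

```python
from itertools import product


def zeros(shape):
    """Nested list of 0.0 with the given shape (a stand-in for tensor.empty())."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def linalg_add_ref(a, b, shape):
    """Reference semantics of a 4-D linalg.add: out[n][c][h][w] = a[...] + b[...]."""
    out = zeros(shape)  # destination; every element is overwritten below
    for n, c, h, w in product(*(range(d) for d in shape)):
        out[n][c][h][w] = a[n][c][h][w] + b[n][c][h][w]
    return out
```

The four implicit loops of linalg.add become explicit here, which is exactly what --convert-linalg-to-loops produces in Step 2.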

Step 2: Conversions in CIRCT

Before running the CIRCT conversions, we first need to eliminate high-level dialect operations such as tensor.empty() and linalg.add.

To this end, we can run

mlir-opt torch-add.mlir --empty-tensor-to-alloc-tensor --one-shot-bufferize="allow-return-allocs-from-loops bufferize-function-boundaries" --buffer-results-to-out-params --convert-linalg-to-loops --canonicalize

to emit the following code:

module {
  func.func @forward(%arg0: memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg1: memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg2: memref<1x3x10x10xf32>) attributes {itypes = "__", otypes = "_"} {
    %c10 = arith.constant 10 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x3x10x10xf32>
    scf.for %arg3 = %c0 to %c3 step %c1 {
      scf.for %arg4 = %c0 to %c10 step %c1 {
        scf.for %arg5 = %c0 to %c10 step %c1 {
          %0 = memref.load %arg0[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>
          %1 = memref.load %arg1[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>
          %2 = arith.addf %0, %1 : f32
          memref.store %2, %alloc[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32>
        }
      }
    }
    memref.copy %alloc, %arg2 : memref<1x3x10x10xf32> to memref<1x3x10x10xf32>
    return
  }
}
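Two details of this output are worth spelling out. The inputs are memrefs with a fully dynamic strided layout, so element [i0, i1, i2, i3] lives at offset + i0*s0 + i1*s1 + i2*s2 + i3*s3 of the underlying buffer; and canonicalization has removed the loop over the size-1 batch dimension, which is why the constant %c0 is used directly as the first index. The sketch below models the loop nest in plain Python, assuming flat lists as buffers and row-major strides for the example; strided_index, row_major_strides, and forward_ref are hypothetical names:

```python
def strided_index(idx, strides, offset):
    """Linear position of element idx in a strided memref: offset + sum(i_k * s_k)."""
    return offset + sum(i * s for i, s in zip(idx, strides))


def row_major_strides(shape):
    """Strides of an identity (row-major) layout, e.g. (1, 3, 10, 10) -> [300, 100, 10, 1]."""
    strides, acc = [], 1
    for d in reversed(shape):
        strides.append(acc)
        acc *= d
    return list(reversed(strides))


def forward_ref(arg0, arg1, arg2, shape, s0, off0, s1, off1):
    """Mirror of the bufferized @forward: add into a temporary buffer, then copy to arg2."""
    n, c_dim, h_dim, w_dim = shape
    alloc = [0.0] * (n * c_dim * h_dim * w_dim)  # memref.alloc (identity layout)
    s_out = row_major_strides(shape)
    for c in range(c_dim):          # scf.for %arg3
        for h in range(h_dim):      # scf.for %arg4
            for w in range(w_dim):  # scf.for %arg5
                idx = (0, c, h, w)  # batch index is the constant %c0
                lhs = arg0[strided_index(idx, s0, off0)]  # memref.load %arg0
                rhs = arg1[strided_index(idx, s1, off1)]  # memref.load %arg1
                alloc[strided_index(idx, s_out, 0)] = lhs + rhs  # arith.addf + memref.store
    arg2[:] = alloc  # memref.copy %alloc, %arg2
```

Because the input strides and offsets are runtime values, the function can accept views into larger buffers; the output parameter %arg2 keeps the identity layout.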

Now we can use the CIRCT tooling, which involves the following steps:

  1. Lower SCF to Calyx;
  2. Lower Calyx to the hw, seq, comb, and sv dialects in preparation for generating SystemVerilog.
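The mlir-opt preprocessing and the two CIRCT steps can be chained from Python by piping MLIR text through each tool's stdin and stdout. This is a sketch under assumptions: the circt-opt pass names --lower-scf-to-calyx and --lower-calyx-to-hw are assumed (check circt-opt --help for your build), any extra options those passes need (such as choosing the top-level function) and the intermediate Calyx passes required before lowering to hw are omitted, and run_stages and as_shell_command are hypothetical helpers:

```python
import shlex
import subprocess

# The mlir-opt invocation from Step 2 as an argv list (input comes from stdin).
# The one-shot-bufferize options form a SINGLE argument, which is why the shell
# command quotes them.
MLIR_OPT_ARGV = [
    "mlir-opt",
    "--empty-tensor-to-alloc-tensor",
    "--one-shot-bufferize=allow-return-allocs-from-loops bufferize-function-boundaries",
    "--buffer-results-to-out-params",
    "--convert-linalg-to-loops",
    "--canonicalize",
]

# ASSUMED circt-opt stages; intermediate Calyx passes are intentionally left out.
CIRCT_STAGES = [
    ["circt-opt", "--handshake-legalize-memrefs", "--lower-scf-to-calyx"],
    ["circt-opt", "--lower-calyx-to-hw"],
]


def as_shell_command(argv):
    """Render an argv list as a correctly quoted shell command."""
    return shlex.join(argv)


def run_stages(mlir_text, stages, run=subprocess.run):
    """Feed MLIR text through each command in order and return the final stdout."""
    text = mlir_text
    for argv in stages:
        result = run(list(argv), input=text, capture_output=True, text=True, check=True)
        text = result.stdout
    return text
```

A full run would look like run_stages(open("torch-add.mlir").read(), [MLIR_OPT_ARGV] + CIRCT_STAGES); the run parameter exists so the chaining logic can be exercised without the tools installed.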

Step 3: Generate SystemVerilog and Simulate on FPGAs

Patch-Specific Work

The SCF-to-Calyx lowering does not currently support memref::CopyOp. Thankfully, the Handshake dialect provides a pass, --handshake-legalize-memrefs, that converts memref::CopyOp into memref::LoadOp and memref::StoreOp operations.
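Conceptually, legalizing a copy means replacing the bulk memref.copy with an explicit loop nest that loads each element from the source and stores it into the destination, using only operations SCF-to-Calyx already handles. A minimal Python sketch of that expansion for the 1x3x10x10 copy in Step 2, assuming dictionaries keyed by index tuples stand in for memrefs; expand_copy is a hypothetical name and this is not CIRCT's implementation:

```python
from itertools import product


def expand_copy(src, dst, shape):
    """Replace a bulk copy with one load and one store per element; return the element count."""
    count = 0
    # product() walks the indices like a loop nest, outermost dimension first
    for idx in product(*(range(d) for d in shape)):
        value = src[idx]  # memref.load
        dst[idx] = value  # memref.store
        count += 1
    return count
```

For the copy in Step 2 this yields 1 * 3 * 10 * 10 = 300 load/store pairs.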

This patch adds a placeholder for memref::CopyOp in the SCF-to-Calyx pass as a reminder to future users that the operation is not yet supported there.

jiahanxie353, Apr 25 '24 18:04