Pipeline For Lowering PyTorch Tensor Add To FPGA
Overview
This patch is the first in a series of PRs that aim to support a complete pipeline from a PyTorch add operation to an FPGA.
The lowering process consists of the following steps:
- Lowering from a PyTorch source program written in Python to MLIR;
- Lowering the emitted MLIR program to CIRCT dialects;
- Exporting SystemVerilog and attempting to simulate it on FPGAs.
Breakdown
Step 1: From PyTorch to MLIR
The source PyTorch program is a simple torch.add:
```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y):
        return x + y
```
This lowering is done through Allo, for example:
```python
import torch
import allo

model = Model()
model.eval()
example_inputs = [torch.rand(1, 3, 10, 10), torch.rand(1, 3, 10, 10)]
mlir_mod = allo.frontend.from_pytorch(
    model, example_inputs=example_inputs, verbose=False, enable_tensor=True
)
print(mlir_mod)
```
Running this prints the following MLIR:
```mlir
module {
  func.func @forward(%arg0: tensor<1x3x10x10xf32>, %arg1: tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32> attributes {itypes = "__", otypes = "_"} {
    %0 = tensor.empty() : tensor<1x3x10x10xf32>
    %1 = linalg.add {name = "add", op_name = "add_0"} ins(%arg0, %arg1 : tensor<1x3x10x10xf32>, tensor<1x3x10x10xf32>) outs(%0 : tensor<1x3x10x10xf32>) -> tensor<1x3x10x10xf32>
    return %1 : tensor<1x3x10x10xf32>
  }
}
```
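To make the semantics of this function concrete, here is a small pure-Python model of what `tensor.empty` and `linalg.add` compute. It is an illustrative sketch, not Allo or MLIR code: nested Python lists stand in for f32 tensors, and the helper names `tensor_empty` and `linalg_add` are hypothetical. The key point it shows is that, on tensors, `linalg.add` produces a new value (`%1`), and the `outs` operand created by `tensor.empty` only supplies the result shape.

```python
from typing import List, Sequence, Union

# Nested Python lists stand in for ranked f32 tensors (illustrative only).
Tensor = Union[float, List["Tensor"]]


def tensor_empty(shape: Sequence[int]) -> Tensor:
    """Model of tensor.empty: a value that only carries a shape.

    The real op leaves the contents unspecified; 0.0 is used here as filler.
    """
    if not shape:
        return 0.0
    return [tensor_empty(shape[1:]) for _ in range(shape[0])]


def linalg_add(lhs: Tensor, rhs: Tensor, out: Tensor) -> Tensor:
    """Model of linalg.add on tensors: elementwise lhs + rhs, shaped like out.

    With tensor semantics the op returns a new value instead of mutating
    its `outs` operand, so `out` only supplies the result shape here.
    """
    if isinstance(out, list):
        if not (isinstance(lhs, list) and isinstance(rhs, list)
                and len(lhs) == len(rhs) == len(out)):
            raise ValueError("ins/outs shapes must match")
        return [linalg_add(a, b, o) for a, b, o in zip(lhs, rhs, out)]
    return lhs + rhs


# Usage on a tiny 1x1x1x2 tensor instead of 1x3x10x10.
x = [[[[1.0, 2.0]]]]
y = [[[[0.5, 0.5]]]]
print(linalg_add(x, y, tensor_empty([1, 1, 1, 2])))  # [[[[1.5, 2.5]]]]
```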
Step 2: Conversions in CIRCT
Before using the conversions in CIRCT, we first need to lower away high-level operations such as `tensor.empty` and `linalg.add`. To this end, we run the following on `torch-add.mlir`, which holds the MLIR from Step 1:

```shell
mlir-opt torch-add.mlir \
  --empty-tensor-to-alloc-tensor \
  --one-shot-bufferize="allow-return-allocs-from-loops bufferize-function-boundaries" \
  --buffer-results-to-out-params \
  --convert-linalg-to-loops \
  --canonicalize
```

This emits the following code:
```mlir
module {
  func.func @forward(%arg0: memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg1: memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>, %arg2: memref<1x3x10x10xf32>) attributes {itypes = "__", otypes = "_"} {
    %c10 = arith.constant 10 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x3x10x10xf32>
    scf.for %arg3 = %c0 to %c3 step %c1 {
      scf.for %arg4 = %c0 to %c10 step %c1 {
        scf.for %arg5 = %c0 to %c10 step %c1 {
          %0 = memref.load %arg0[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>
          %1 = memref.load %arg1[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32, strided<[?, ?, ?, ?], offset: ?>>
          %2 = arith.addf %0, %1 : f32
          memref.store %2, %alloc[%c0, %arg3, %arg4, %arg5] : memref<1x3x10x10xf32>
        }
      }
    }
    memref.copy %alloc, %arg2 : memref<1x3x10x10xf32> to memref<1x3x10x10xf32>
    return
  }
}
```
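The bufferized function above can be mirrored step by step in Python to show what each piece does: the two inputs use a fully dynamic strided layout (`strided<[?, ?, ?, ?], offset: ?>`), the result is written into a fresh `%alloc`, and `--buffer-results-to-out-params` turned the returned value into the out-parameter `%arg2`, filled by `memref.copy`. This is a sketch under stated assumptions: `MemRef`, `identity_strides`, `alloc`, and `forward` are hypothetical helpers that model memrefs as flat Python lists with an offset and strides. The batch dimension has size 1, which is why the lowered code has no loop for it and indexes it with `%c0`.

```python
import itertools
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class MemRef:
    """Hypothetical model of a strided memref: flat storage plus a layout."""
    data: List[float]
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]
    offset: int = 0

    def _address(self, idx: Sequence[int]) -> int:
        # strided<[s0, s1, ...], offset: o> maps idx to o + sum(i_k * s_k).
        return self.offset + sum(i * s for i, s in zip(idx, self.strides))

    def load(self, idx: Sequence[int]) -> float:
        return self.data[self._address(idx)]

    def store(self, value: float, idx: Sequence[int]) -> None:
        self.data[self._address(idx)] = value


def identity_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides, i.e. the default layout of memref<1x3x10x10xf32>."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def alloc(shape: Sequence[int]) -> MemRef:
    """Model of memref.alloc with the identity layout (contents zeroed here)."""
    size = 1
    for dim in shape:
        size *= dim
    return MemRef([0.0] * size, tuple(shape), identity_strides(shape))


def forward(arg0: MemRef, arg1: MemRef, arg2: MemRef) -> None:
    """Mirror of the bufferized @forward; the result lands in out-param arg2."""
    c0 = 0                                   # batch dim has size 1
    buf = alloc(arg2.shape)                  # %alloc
    _, channels, height, width = arg2.shape
    for a3 in range(channels):               # scf.for %arg3 = %c0 to %c3
        for a4 in range(height):             # scf.for %arg4 = %c0 to %c10
            for a5 in range(width):          # scf.for %arg5 = %c0 to %c10
                idx = (c0, a3, a4, a5)
                buf.store(arg0.load(idx) + arg1.load(idx), idx)  # arith.addf
    # memref.copy %alloc, %arg2
    for idx in itertools.product(*(range(d) for d in arg2.shape)):
        arg2.store(buf.load(idx), idx)


shape = (1, 3, 10, 10)
arg0 = MemRef([float(i) for i in range(300)], shape, identity_strides(shape))
# arg1 views a larger buffer at offset 5, which the dynamic `offset: ?` permits.
arg1 = MemRef([1.0] * 305, shape, identity_strides(shape), offset=5)
arg2 = alloc(shape)
forward(arg0, arg1, arg2)
print(arg2.load((0, 2, 9, 9)))  # 300.0
```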
Now we can use the CIRCT tools, which involves the following steps:
- Lower SCF to Calyx;
- Lower Calyx to the `hw`, `seq`, `comb`, and `sv` dialects in preparation for generating SystemVerilog code.
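The two stages can be chained from a small Python driver, sketched below. The `mlir-opt` flags are copied from the command in Step 2. The `circt-opt` pass names `--lower-scf-to-calyx` and `--lower-calyx-to-hw` are assumptions about how the two CIRCT steps are spelled, and a working pipeline may need extra options or intermediate Calyx passes, so check `circt-opt --help` for your build. `--handshake-legalize-memrefs` is placed first so that no `memref.copy` reaches the SCF-to-Calyx conversion. The function names `stage_argv`, `run_stage`, and `lower` are hypothetical.

```python
import shutil
import subprocess
import sys
from typing import List, Sequence

# Copied from the mlir-opt command in Step 2; after shell unquoting, the
# one-shot-bufferize options form a single argument.
MLIR_OPT_FLAGS = [
    "--empty-tensor-to-alloc-tensor",
    "--one-shot-bufferize=allow-return-allocs-from-loops bufferize-function-boundaries",
    "--buffer-results-to-out-params",
    "--convert-linalg-to-loops",
    "--canonicalize",
]

# ASSUMPTION: circt-opt pass names for the two CIRCT steps; verify them (and
# any intermediate Calyx passes or options) with `circt-opt --help`.
# The memref legalization runs first so no memref.copy reaches SCF-to-Calyx.
CIRCT_OPT_FLAGS = [
    "--handshake-legalize-memrefs",
    "--lower-scf-to-calyx",
    "--lower-calyx-to-hw",
]


def stage_argv(tool: str, flags: Sequence[str]) -> List[str]:
    # No input file is given: both tools read MLIR from stdin by default.
    return [tool, *flags]


def run_stage(tool: str, flags: Sequence[str], mlir_text: str) -> str:
    """Run one tool on mlir_text and return the transformed MLIR."""
    if shutil.which(tool) is None:
        raise FileNotFoundError(f"{tool} not found on PATH")
    result = subprocess.run(stage_argv(tool, flags), input=mlir_text,
                            capture_output=True, text=True, check=True)
    return result.stdout


def lower(mlir_text: str) -> str:
    loops = run_stage("mlir-opt", MLIR_OPT_FLAGS, mlir_text)
    return run_stage("circt-opt", CIRCT_OPT_FLAGS, loops)


if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1]) as f:  # e.g. torch-add.mlir
        print(lower(f.read()))
```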
Step 3: Generate SystemVerilog and Simulate on FPGAs
Patch-Specific Work
When lowering SCF to Calyx, there is no support for memref::CopyOp. Thankfully, Handshake provides a pass, --handshake-legalize-memrefs, that converts memref::CopyOp into memref::LoadOp and memref::StoreOp.
This patch adds a placeholder for memref::CopyOp in the SCF-to-Calyx pass as a reminder for future users.
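The idea behind this workaround is that a whole-buffer copy is equivalent to one load and one store per element, which the SCF-to-Calyx lowering can already handle. The Python sketch below illustrates that equivalence. It models the effect of the rewrite, not the exact IR that --handshake-legalize-memrefs emits, and the helper names `legalized_copy_ops` and `execute` are hypothetical.

```python
import itertools
from typing import Dict, Iterator, Sequence, Tuple

Index = Tuple[int, ...]
Op = Tuple[str, str, Index]


def legalized_copy_ops(shape: Sequence[int], src: str, dst: str) -> Iterator[Op]:
    """Yield loads/stores equivalent to `memref.copy src, dst` (illustrative).

    One loop per dimension; the innermost index varies fastest.
    """
    for idx in itertools.product(*(range(d) for d in shape)):
        yield ("memref.load", src, idx)
        yield ("memref.store", dst, idx)


def execute(ops: Iterator[Op], mem: Dict[str, Dict[Index, float]]) -> None:
    """Run the op stream against dict-backed memories."""
    value = 0.0
    for kind, name, idx in ops:
        if kind == "memref.load":
            value = mem[name][idx]
        else:
            mem[name][idx] = value


shape = (1, 3, 10, 10)
ops = list(legalized_copy_ops(shape, "%alloc", "%arg2"))
print(len(ops))  # 600: one load and one store per element

indices = itertools.product(*(range(d) for d in shape))
mem = {"%alloc": {idx: float(i) for i, idx in enumerate(indices)}, "%arg2": {}}
execute(ops, mem)
print(mem["%arg2"] == mem["%alloc"])  # True
```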