torch-mlir
torch-mlir copied to clipboard
[Tosa backend] Support MatmulStaticBroadcast_basic for tosa
This PR fixes issue #2581.
Test case
class Matmul(nn.Module):
    """Minimal module that forwards its two operands to ``torch.matmul``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, weight):
        """Return ``torch.matmul(x, weight)``."""
        return torch.matmul(x, weight)
def gen_mt_mlir():
    """Compile the matmul reproduction case to TOSA.

    Runs the model eagerly once and prints the result shape, then lowers the
    same model and operands with ``torch_mlir.compile``.

    Returns:
        The object returned by ``torch_mlir.compile`` (previously discarded).
    """
    model = Matmul()
    # Rank-4 weight against a rank-3 input, shapes as given in the report.
    weight = torch.randn(80, 300, 250, 150, dtype=torch.float32)
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.ones(300, 100, 250)
    res = model(inputs, weight)
    print(res.shape)
    module = torch_mlir.compile(model, [inputs, weight], output_type="TOSA")
    return module
# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    gen_mt_mlir()
Before bug fix
// IR for matmul(300x100x250, 80x300x250x150) -> 80x300x100x150, dumped before the fix.
func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
// LHS: expand to rank 4, transpose, then collapse back to rank 3 (300x100x250).
%2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
%3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
%5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
// RHS: move the size-80 dim next to the last dim and merge them (80 * 150 = 12000).
%6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
%8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
// Rank-3 matmul, then split the merged 12000 dim back into 80x150.
%9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
%10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
// BUG: perms [1, 2, 0, 3] applied to 300x100x80x150 give 100x80x300x150,
// which does not match the 80x300x100x150 result type declared on %12.
%11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
// The original torch op is still present in this dump and is what gets returned.
%13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
return %13 : !torch.vtensor<[80,300,100,150],f32>
}
After bug fix
// IR for matmul(300x100x250, 80x300x250x150) -> 80x300x100x150, dumped after the fix.
func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
%0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
// LHS: expand to rank 4, transpose, then collapse back to rank 3 (300x100x250).
%2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
%3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
%5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
// RHS: move the size-80 dim next to the last dim and merge them (80 * 150 = 12000).
%6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
%8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
// Rank-3 matmul, then split the merged 12000 dim back into 80x150.
%9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
%10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
// perms [2, 0, 1, 3] applied to 300x100x80x150 give 80x300x100x150,
// matching the result type declared on %12.
%11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
// The original torch op is still present in this dump and is what gets returned.
%13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
return %13 : !torch.vtensor<[80,300,100,150],f32>
}
Result analysis
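The mistake was in the permutation constant %11. Applied to the 300x100x80x150 operand of the final transpose, the permutation [1, 2, 0, 3] yields 100x80x300x150 rather than the expected 80x300x100x150. Before the fix it was: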
%11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
With the fix it becomes the following, which maps 300x100x80x150 to the expected 80x300x100x150:
%11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>