torch-mlir icon indicating copy to clipboard operation
torch-mlir copied to clipboard

[Tosa backend] Support MatmulStaticBroadcast_basic for tosa

Open bilibiliGO283 opened this issue 1 year ago • 4 comments

This PR is to solve this issue: #2581

Test case

class Matmul(nn.Module):
    def __init__(self):
        super(Matmul, self).__init__()
 
    def forward(self, x, weight):
        r = torch.matmul(x, weight)
        return r

def gen_mt_mlir():
    """Build the Matmul reproducer, run it eagerly, and lower it to TOSA.

    Eager execution prints the expected result shape
    (``torch.Size([80, 300, 100, 150])``) so it can be compared against the
    lowered IR's result type.

    Returns:
        The compiled torch-mlir module, so callers can inspect the emitted
        TOSA IR (the original version discarded it).
    """
    model = Matmul()
    weight = torch.randn(80, 300, 250, 150, dtype=torch.float32)
    # Renamed from ``input`` to avoid shadowing the builtin.
    x = torch.ones(300, 100, 250)
    res = model(x, weight)
    print(res.shape)
    module = torch_mlir.compile(model, [x, weight], output_type="TOSA")
    return module

# Script entry point: run the reproducer when executed directly.
if __name__ == "__main__":
    gen_mt_mlir()

Before bug fix

// "Before fix" lowering of torch.aten.matmul (3-D LHS broadcast against a
// 4-D RHS) to TOSA. The final transpose's permutation constant is wrong.
func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
  %2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
  %3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
  %5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
  %6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
  // Broadcast dims (80) and the reduced N dim (150) are folded together so a
  // single rank-3 tosa.matmul can be used: 80 * 150 = 12000.
  %8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
  %9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
  %10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
  // BUG: %10 has shape [300,100,80,150] and %12's declared type needs
  // [80,300,100,150]. With tosa.transpose semantics out[i] = in[perms[i]],
  // the correct permutation is [2,0,1,3]; the emitted [1,2,0,3] would
  // actually produce [100,80,300,150], contradicting the result type below.
  %11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
  %13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
  return %13 : !torch.vtensor<[80,300,100,150],f32>
}

After bug fix

// "After fix" lowering: identical to the buggy version except the final
// transpose permutation constant (%11), which is now correct.
func.func @forward(%arg0: !torch.vtensor<[300,100,250],f32>, %arg1: !torch.vtensor<[80,300,250,150],f32>) -> !torch.vtensor<[80,300,100,150],f32> {
  %0 = builtin.unrealized_conversion_cast %arg1 : !torch.vtensor<[80,300,250,150],f32> to tensor<80x300x250x150xf32>
  %1 = builtin.unrealized_conversion_cast %arg0 : !torch.vtensor<[300,100,250],f32> to tensor<300x100x250xf32>
  %2 = tosa.reshape %1 {new_shape = array<i64: 1, 300, 100, 250>} : (tensor<300x100x250xf32>) -> tensor<1x300x100x250xf32>
  %3 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %4 = tosa.transpose %2, %3 : (tensor<1x300x100x250xf32>, tensor<4xi32>) -> tensor<300x1x100x250xf32>
  %5 = tosa.reshape %4 {new_shape = array<i64: 300, 100, 250>} : (tensor<300x1x100x250xf32>) -> tensor<300x100x250xf32>
  %6 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %7 = tosa.transpose %0, %6 : (tensor<80x300x250x150xf32>, tensor<4xi32>) -> tensor<300x250x80x150xf32>
  %8 = tosa.reshape %7 {new_shape = array<i64: 300, 250, 12000>} : (tensor<300x250x80x150xf32>) -> tensor<300x250x12000xf32>
  %9 = tosa.matmul %5, %8 : (tensor<300x100x250xf32>, tensor<300x250x12000xf32>) -> tensor<300x100x12000xf32>
  %10 = tosa.reshape %9 {new_shape = array<i64: 300, 100, 80, 150>} : (tensor<300x100x12000xf32>) -> tensor<300x100x80x150xf32>
  // FIX: with out[i] = in[perms[i]], perms [2,0,1,3] maps %10's
  // [300,100,80,150] to [80,300,100,150], matching %12's declared type.
  %11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %12 = tosa.transpose %10, %11 : (tensor<300x100x80x150xf32>, tensor<4xi32>) -> tensor<80x300x100x150xf32>
  %13 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[300,100,250],f32>, !torch.vtensor<[80,300,250,150],f32> -> !torch.vtensor<[80,300,100,150],f32>
  return %13 : !torch.vtensor<[80,300,100,150],f32>
}

Result analysis

There is a mistake in the transpose permutation constant `%11`:

%11 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>

The fix changes it to this:

%11 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>

bilibiliGO283 avatar Nov 24 '23 04:11 bilibiliGO283