[TORCH-LINALG] Add recomposition for select+copy_

Open bilibiliGO283 opened this issue 1 year ago • 2 comments

I am relatively new to open-source development and am currently working on Torch-MLIR, so I may have missed part of the contribution process or made mistakes in my code. I would be very grateful if someone could review this and help me identify and fix any issues.

issue description

Converting the following PyTorch model fails:

import torch

class SelectAndCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        new_x = x[0]
        new_x.copy_(y)
        return x
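In eager PyTorch, `x[0]` is a view that shares storage with `x`, so `new_x.copy_(y)` writes into row 0 of `x`, and the returned `x` must reflect that write. The sketch below mimics this aliasing with plain Python nested lists; it is only an analogy (lists stand in for tensors, and `forward_like` is a hypothetical name), not how torch or torch-mlir works internally.

```python
def forward_like(x, y):
    """Mimic SelectAndCopy.forward on nested lists (analogy only)."""
    new_x = x[0]   # the inner list object itself: an alias, like the view x[0]
    new_x[:] = y   # in-place overwrite through the alias, like new_x.copy_(y)
    return x       # x now shows the write, just as in the torch model

x = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
y = [7.0, 8.0, 9.0]
out = forward_like(x, y)
print(out)       # [[7.0, 8.0, 9.0], [1.0, 1.0, 1.0]]
print(out is x)  # True
```

This aliasing is why the compiler cannot treat `new_x` as an independent value: the write has to land in `x`.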

Without RecomposeSelectCopy_, lowering fails after several iterations with this error:

error: unsupported by backend contract: tensor with unknown rank

This is the IR that failed to convert:

// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "SelectAndCopy"} {
  func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[100],f32>) -> !torch.vtensor {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,100],unk>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.tensor<[1,100],unk>, !torch.int -> !torch.tensor<[100],unk>
    %4 = torch.tensor_static_info_cast %3 : !torch.tensor<[100],unk> to !torch.tensor
    %5 = torch.copy.to_vtensor %4 : !torch.vtensor
    %6 = torch.aten.copy %5, %arg1, %false : !torch.vtensor, !torch.vtensor<[100],f32>, !torch.bool -> !torch.vtensor<[100],unk>
    %7 = torch.tensor_static_info_cast %6 : !torch.vtensor<[100],unk> to !torch.vtensor
    torch.overwrite.tensor.contents %7 overwrites %4 : !torch.vtensor, !torch.tensor
    %8 = torch.copy.to_vtensor %1 : !torch.vtensor
    return %8 : !torch.vtensor
  }
}
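The error means some tensor in the final IR carries no rank information. In torch-mlir's type syntax, a shaped tensor type lists its sizes in angle brackets (e.g. `!torch.vtensor<[100],f32>`), while a bare `!torch.vtensor` or `!torch.tensor`, or a `<*,...>` shape, has unknown rank. The hypothetical helper below scans a line of IR text for such types, as a rough way to locate offending values in a dump like the one above. It is a text heuristic, assuming this type syntax, not how torch-mlir itself checks the backend contract.

```python
import re

# Matches !torch.tensor / !torch.vtensor, optionally followed by <...> parameters.
_TYPE_RE = re.compile(r"!torch\.v?tensor(<[^>]*>)?")

def unknown_rank_types(ir_line):
    """Return the tensor types on ir_line whose rank is unknown (heuristic)."""
    found = []
    for m in _TYPE_RE.finditer(ir_line):
        params = m.group(1)
        # No <...> at all, or a '*' shape, means the rank is not known.
        if params is None or params.startswith("<*"):
            found.append(m.group(0))
    return found

# Line %6 from the failing dump: the copy's first operand has a bare type.
line = ("%6 = torch.aten.copy %5, %arg1, %false : !torch.vtensor, "
        "!torch.vtensor<[100],f32>, !torch.bool -> !torch.vtensor<[100],unk>")
print(unknown_rank_types(line))  # ['!torch.vtensor']
```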

solution

The fix adopts the approach from https://github.com/llvm/torch-mlir/pull/2150. RecomposeSelectCopy_ rewrites the torch.aten.select.int operator into a torch.aten.slice.Tensor operator. RecomposeSliceCopy_ then rewrites the resulting slice + copy_ pair into an index_put operator. The IR before RecomposeComplexOpsPass runs:

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[10],f32>) -> (!torch.tensor, !torch.tensor) {
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %2 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[10],f32> to !torch.vtensor
  %3 = torch.copy.to_tensor %2 : !torch.tensor
  %4 = torch.aten.select.int %1, %int0, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  %5 = torch.aten.copy_ %4, %3, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
  return %1, %4 : !torch.tensor, !torch.tensor
}
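The two rewrites described above can be modeled on a toy op list. This sketch is not torch-mlir's C++ pattern API: the `Op` dataclass, the string op names, and the `recompose_*` / `drop_dead_views` functions are hypothetical stand-ins that only mirror the shape of the transformation, and the slice step handles only dim 0 (the case in this issue). Each rewrite is rooted at the `copy_`: `select(x, dim, i)` becomes `slice(x, dim, i, i + 1, 1)`, and a `copy_` into a dim-0 slice becomes an `arange(start, end, step)` index plus an `index_put` on the base tensor.

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str        # toy op name: "select", "slice", "copy_", "arange", "index_put"
    result: str      # SSA-like result name
    operands: tuple  # operand names, constant ints, or a tuple of index names

def defining_op(ops, value):
    """Return the op that produces `value`, or None."""
    return next((op for op in ops if op.result == value), None)

def recompose_select_copy(ops):
    """copy_(select(x, dim, i), y) -> copy_(slice(x, dim, i, i + 1, 1), y)."""
    out = []
    for op in ops:
        if op.name == "copy_":
            src = defining_op(ops, op.operands[0])
            if src is not None and src.name == "select":
                x, dim, i = src.operands
                sl = Op("slice", src.result + "_sl", (x, dim, i, i + 1, 1))
                out += [sl, Op("copy_", op.result, (sl.result, op.operands[1]))]
                continue
        out.append(op)
    return out

def recompose_slice_copy(ops):
    """copy_(slice(x, 0, start, end, step), y) -> index_put(x, [arange(start, end, step)], y)."""
    out = []
    for op in ops:
        if op.name == "copy_":
            src = defining_op(ops, op.operands[0])
            if src is not None and src.name == "slice" and src.operands[1] == 0:
                x, _dim, start, end, step = src.operands
                idx = Op("arange", op.result + "_idx", (start, end, step))
                out += [idx, Op("index_put", op.result, (x, (idx.result,), op.operands[1]))]
                continue
        out.append(op)
    return out

def _operand_names(op):
    for operand in op.operands:
        if isinstance(operand, tuple):
            yield from operand
        else:
            yield operand

def drop_dead_views(ops):
    """Remove select/slice ops whose results are no longer used."""
    used = {name for op in ops for name in _operand_names(op)}
    return [op for op in ops if op.name not in ("select", "slice") or op.result in used]

prog = [Op("select", "new_x", ("x", 0, 0)), Op("copy_", "c", ("new_x", "y"))]
prog = drop_dead_views(recompose_slice_copy(recompose_select_copy(prog)))
for op in prog:
    print(op.name, op.result, op.operands)
# arange c_idx (0, 1, 1)
# index_put c ('x', ('c_idx',), 'y')
```

The result has the same shape as the IR after the pass: an `arange` over `[0, 1)` feeding an index put into the original tensor.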

The following is the IR after RecomposeComplexOpsPass has run:

func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[100],f32>) -> !torch.tensor {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %2 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[100],f32> to !torch.vtensor
  %3 = torch.copy.to_tensor %2 : !torch.tensor
  %4 = torch.aten.arange.start_step %int0, %int1, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<[?],unk>
  %5 = torch.prim.ListConstruct %4 : (!torch.tensor<[?],unk>) -> !torch.list<optional<tensor>>
  %6 = torch.aten._index_put_impl_ %1, %5, %3, %false, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
  return %1 : !torch.tensor
}

The select + copy_ pair is replaced by an arange index and a single _index_put_impl_ op that writes directly into the original tensor.
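The rewrite preserves semantics because, for a 2-D input, copying `y` into `x.select(0, i)`, copying it into `x.slice(0, i, i + 1)` (with `y` broadcast from `[N]` to `[1, N]`), and an index put with index `arange(i, i + 1)` all overwrite exactly row `i` of `x`. The pure-Python sketch below checks this on nested lists. The helper names are hypothetical, and the code models only the semantics for 2-D inputs along dim 0 with accumulate set to false, not the torch-mlir lowering.

```python
import copy

def via_select(x, i, y):
    x = copy.deepcopy(x)
    x[i][:] = list(y)                  # select(0, i) is row i; copy_ overwrites it
    return x

def via_slice(x, start, end, step, y):
    x = copy.deepcopy(x)
    for r in range(start, end, step):  # rows covered by slice(0, start, end, step)
        x[r][:] = list(y)              # y of shape [N] broadcasts to each sliced row
    return x

def via_index_put(x, indices, values):
    x = copy.deepcopy(x)
    for r in indices:                  # 1-D index on dim 0, no accumulation
        x[r][:] = list(values)         # values of shape [N] broadcast across indexed rows
    return x

x = [[0, 0, 0], [1, 1, 1], [2, 2, 2]]
y = [7, 8, 9]
i = 1
a = via_select(x, i, y)
b = via_slice(x, i, i + 1, 1, y)
c = via_index_put(x, list(range(i, i + 1, 1)), y)  # plays the role of arange(i, i + 1, 1)
print(a == b == c, a)  # True [[0, 0, 0], [7, 8, 9], [2, 2, 2]]
```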

bilibiliGO283, Aug 02 '23 10:08