torch-mlir
[TORCH-LINALG] Add recomposition for select+copy_
I am relatively new to open-source development and am currently working on Torch-MLIR, so I may have overlooked some processes or made mistakes in my code. I would be grateful if someone could review it and help me identify and address any issues.
Issue description
Converting the following PyTorch model produces an error:
```python
import torch


class SelectAndCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        new_x = x[0]      # select row 0 (a view of x)
        new_x.copy_(y)    # in-place copy through the view mutates x
        return x
```
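In eager PyTorch, `x[0]` is a view of the first row. `copy_` therefore writes through to `x`, and `forward` returns the mutated input. In the IR below, this in-place write through an alias appears as `torch.aten.copy` followed by `torch.overwrite.tensor.contents`.

The sketch below is a rough analogy, not torch code. It relies on the fact that indexing a Python list of lists also returns an alias of the inner row. The function name `select_and_copy` is hypothetical.

```python
# Plain-Python analogy of SelectAndCopy.forward: nested lists stand in for
# tensors (an assumption for illustration; this is NOT torch code).

def select_and_copy(x, y):
    """Mimic `new_x = x[0]; new_x.copy_(y); return x`."""
    new_x = x[0]   # like select: an alias of row 0, not a copy
    new_x[:] = y   # like copy_: overwrite the aliased row in place
    return x       # the mutation is visible through x


x = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
y = [7.0, 8.0, 9.0]
out = select_and_copy(x, y)
print(out)  # [[7.0, 8.0, 9.0], [1.0, 1.0, 1.0]]
```

The slice assignment `new_x[:] = y` keeps the row's identity, which mirrors how `copy_` keeps the destination storage. It assumes `y` has the same length as the row.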
Without `RecomposeSelectCopy_`, lowering runs through several iterations and eventually fails with this error:

```
error: unsupported by backend contract: tensor with unknown rank
```
This is the IR that failed to convert:
```mlir
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "SelectAndCopy"} {
  func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[100],f32>) -> !torch.vtensor {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,100],unk>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.tensor<[1,100],unk>, !torch.int -> !torch.tensor<[100],unk>
    %4 = torch.tensor_static_info_cast %3 : !torch.tensor<[100],unk> to !torch.tensor
    %5 = torch.copy.to_vtensor %4 : !torch.vtensor
    %6 = torch.aten.copy %5, %arg1, %false : !torch.vtensor, !torch.vtensor<[100],f32>, !torch.bool -> !torch.vtensor<[100],unk>
    %7 = torch.tensor_static_info_cast %6 : !torch.vtensor<[100],unk> to !torch.vtensor
    torch.overwrite.tensor.contents %7 overwrites %4 : !torch.vtensor, !torch.tensor
    %8 = torch.copy.to_vtensor %1 : !torch.vtensor
    return %8 : !torch.vtensor
  }
}
```
Solution
The fix adopts the approach from https://github.com/llvm/torch-mlir/pull/2150. It works in two steps:

- `RecomposeSelectCopy_` converts the `torch.aten.select.int` operator into a `torch.aten.slice.Tensor` operator.
- `RecomposeSliceCopy_` then converts the resulting slice + copy_ pair into an index_put operator.
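The two-step rewrite can be sketched as a toy term rewriter. The tuples are hypothetical stand-ins for MLIR ops, and the function names are illustrative only. The `None` padding for leading dimensions is an assumption based on PyTorch's `index_put` semantics (a `None` index leaves that dimension unindexed). It is not a transcription of torch-mlir's C++ patterns.

```python
from typing import Any, Tuple

# Hypothetical term encoding (not MLIR):
#   ("select", x, dim, index)             ("slice", x, dim, start, end, step)
#   ("copy_", dst, src)                   ("arange", start, end, step)
#   ("index_put", x, indices, values)
Term = Tuple[Any, ...]


def recompose_select_copy(term: Term) -> Term:
    """copy_(select(x, d, i), src) -> copy_(slice(x, d, i, i + 1, 1), src)."""
    if term[0] == "copy_" and term[1][0] == "select":
        _, (_, x, dim, index), src = term
        return ("copy_", ("slice", x, dim, index, index + 1, 1), src)
    return term


def recompose_slice_copy(term: Term) -> Term:
    """copy_(slice(x, d, s, e, st), src) -> index_put(x, [None]*d + [arange(s, e, st)], src)."""
    if term[0] == "copy_" and term[1][0] == "slice":
        _, (_, x, dim, start, end, step), src = term
        indices = [None] * dim + [("arange", start, end, step)]
        return ("index_put", x, indices, src)
    return term


def recompose(term: Term) -> Term:
    # Apply the two rewrites in the order described above.
    return recompose_slice_copy(recompose_select_copy(term))


before = ("copy_", ("select", "%1", 0, 0), "%3")
print(recompose(before))  # ('index_put', '%1', [('arange', 0, 1, 1)], '%3')
```

For `select(x, 0, 0)`, the index list is a single `arange(0, 1, 1)`. This matches the `torch.aten.arange.start_step %int0, %int1, %int1` op in the transformed IR.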
IR before `RecomposeComplexOpsPass` runs:
```mlir
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[10],f32>) -> (!torch.tensor, !torch.tensor) {
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %2 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[10],f32> to !torch.vtensor
  %3 = torch.copy.to_tensor %2 : !torch.tensor
  %4 = torch.aten.select.int %1, %int0, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
  %5 = torch.aten.copy_ %4, %3, %false : !torch.tensor, !torch.tensor, !torch.bool -> !torch.tensor
  return %1, %4 : !torch.tensor, !torch.tensor
}
```
IR after `RecomposeComplexOpsPass`:
```mlir
func.func @forward(%arg0: !torch.vtensor<[100,100],f32>, %arg1: !torch.vtensor<[100],f32>) -> !torch.tensor {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[100,100],f32> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %2 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[100],f32> to !torch.vtensor
  %3 = torch.copy.to_tensor %2 : !torch.tensor
  %4 = torch.aten.arange.start_step %int0, %int1, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<[?],unk>
  %5 = torch.prim.ListConstruct %4 : (!torch.tensor<[?],unk>) -> !torch.list<optional<tensor>>
  %6 = torch.aten._index_put_impl_ %1, %5, %3, %false, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
  return %1 : !torch.tensor
}
```
The select + copy_ pair has been rewritten into a single `torch.aten._index_put_impl_` op. The selected row is now expressed as the index tensor `arange(0, 1, 1)`.
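The sketch below checks why the rewrite preserves semantics. It compares the eager select + copy_ with an `index_put` along dim 0 whose indices are `arange(0, 1, 1)`. Nested Python lists stand in for tensors, which is an assumption for illustration; the helper names are hypothetical. The 1-D source broadcasts across the single indexed row, in the same way the `[100]` source broadcasts to the `[1, 100]` indexed region in the IR.

```python
import copy


def eager_select_copy(x, index, y):
    # new_x = x[index]; new_x.copy_(y): write y through the row alias
    x[index][:] = y
    return x


def index_put_dim0(x, indices, values):
    # Models _index_put_impl_(x, [indices], values, accumulate=False) on dim 0.
    # A 1-D `values` matching the row length broadcasts to every indexed row.
    for i in indices:
        x[i] = list(values)
    return x


x = [[float(r * 10 + c) for c in range(4)] for r in range(3)]
y = [-1.0, -2.0, -3.0, -4.0]

a = eager_select_copy(copy.deepcopy(x), 0, y)
b = index_put_dim0(copy.deepcopy(x), list(range(0, 1, 1)), y)  # arange(0, 1, 1)
print(a == b)  # True
```

The deep copies keep the two paths independent, so the comparison reflects only the effect of each operation on the original `x`.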