torch-mlir
How to lower the pattern `torch.aten.Int.Scalar(torch.aten.item(rank0_tensor))`?
I want to lower a Torch dialect module like the following:
module {
func @main(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.number
%1 = torch.aten.Int.Scalar %0 : !torch.number -> !torch.int
%2 = torch.aten.add.Scalar %arg0, %1, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?,?],f32>
}
}
However, I can't add a per-op converter in a file such as lib/Conversion/TorchToTosa/TorchToTosa.cpp, because there is no TypeConverter between !torch.number and !torch.int.
Does anyone know how to lower this IR module?
The torch-refine-types pass should convert it to this:
module {
func.func @main(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int0 = torch.constant.int 0
%0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
%1 = torch.aten.Int.Scalar %0 : !torch.int -> !torch.int
%2 = torch.aten.add.Scalar %arg0, %1, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?,?],f32>
}
}
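For readers unfamiliar with what the pass is expected to do here, a toy Python model may help. It is a hypothetical sketch, not torch-mlir's implementation: the `Op` class, `element_dtype`, `refine_item`, and the dtype table are all assumptions made for illustration. It captures the one inference the refined IR relies on: a rank-0 `si64` tensor can only yield an integer, so the result of `torch.aten.item` can be narrowed from `!torch.number` to `!torch.int`.

```python
from dataclasses import dataclass, replace
from typing import Optional

# Hypothetical, simplified stand-in for a single-operand Torch dialect op.
@dataclass(frozen=True)
class Op:
    name: str           # e.g. "torch.aten.item"
    operand_type: str   # e.g. "!torch.vtensor<[],si64>"
    result_type: str    # e.g. "!torch.number"

# Assumed dtype -> scalar-type table (a small illustrative subset).
SCALAR_TYPE_FOR_DTYPE = {
    "si64": "!torch.int",
    "si32": "!torch.int",
    "f32": "!torch.float",
    "f64": "!torch.float",
}

def element_dtype(vtensor_type: str) -> Optional[str]:
    """Return the dtype of a type string such as '!torch.vtensor<[],si64>', or None."""
    prefix = "!torch.vtensor<"
    if not (vtensor_type.startswith(prefix) and vtensor_type.endswith(">")):
        return None  # no dtype information: nothing to refine from
    _, _, dtype = vtensor_type[len(prefix):-1].rpartition(",")
    return dtype or None

def refine_item(op: Op) -> Op:
    """Narrow torch.aten.item's !torch.number result when the operand dtype is known."""
    if op.name != "torch.aten.item" or op.result_type != "!torch.number":
        return op
    refined = SCALAR_TYPE_FOR_DTYPE.get(element_dtype(op.operand_type), "!torch.number")
    return replace(op, result_type=refined)

item = Op("torch.aten.item", "!torch.vtensor<[],si64>", "!torch.number")
print(refine_item(item).result_type)  # !torch.int
```

When the dtype is unknown, the model deliberately keeps `!torch.number`, since narrowing without evidence would be unsound.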
The torch.aten.Int.Scalar op should then be folded away by a later pass such as the canonicalizer.
It appears that the torch-refine-types pass has a bug. Would you be able to try debugging it?
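The expected folding step can be sketched the same way. Again this is a hypothetical toy model, not the MLIR canonicalizer or torch-mlir's folder: `Op` and `canonicalize` are illustrative names. It folds `torch.aten.Int.Scalar` when its operand is already `!torch.int` (forwarding the operand to all users), then drops `torch.constant.int` ops whose results are no longer used. Function arguments such as `%arg0` are not tracked.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Hypothetical, simplified stand-in for an op in a single-block function.
@dataclass(frozen=True)
class Op:
    result: str                # SSA result name, e.g. "%1" ("" if none)
    name: str                  # op name, e.g. "torch.aten.Int.Scalar"
    operands: Tuple[str, ...]  # SSA operand names
    result_type: str = ""      # e.g. "!torch.int"

def canonicalize(ops: List[Op]) -> List[Op]:
    types = {op.result: op.result_type for op in ops if op.result}
    forwarded: Dict[str, str] = {}  # folded result -> value that replaces it
    kept: List[Op] = []
    for op in ops:
        operands = tuple(forwarded.get(v, v) for v in op.operands)
        # Fold: Int.Scalar applied to a value that is already !torch.int is the identity.
        if (op.name == "torch.aten.Int.Scalar" and len(operands) == 1
                and types.get(operands[0]) == "!torch.int"):
            forwarded[op.result] = operands[0]
            continue
        kept.append(Op(op.result, op.name, operands, op.result_type))
    # Dead-code elimination, restricted here to unused torch.constant.int ops.
    used = {v for op in kept for v in op.operands}
    return [op for op in kept
            if not (op.name == "torch.constant.int" and op.result not in used)]

# The refined IR shown above, with %0 already typed as !torch.int.
ir = [
    Op("%int1", "torch.constant.int", (), "!torch.int"),
    Op("%int3", "torch.constant.int", (), "!torch.int"),
    Op("%int2", "torch.constant.int", (), "!torch.int"),
    Op("%int0", "torch.constant.int", (), "!torch.int"),
    Op("%0", "torch.aten.item", ("%arg1",), "!torch.int"),
    Op("%1", "torch.aten.Int.Scalar", ("%0",), "!torch.int"),
    Op("%2", "torch.aten.add.Scalar", ("%arg0", "%1", "%int1"), "!torch.vtensor<[?,?,?,?],f32>"),
    Op("", "return", ("%2",)),
]
print([op.name for op in canonicalize(ir)])
# ['torch.constant.int', 'torch.aten.item', 'torch.aten.add.Scalar', 'return']
```

Only `%int1`, the `torch.aten.item`, the `torch.aten.add.Scalar` (now consuming `%0` directly), and the `return` remain. No `!torch.number` value is left, which is presumably what makes the module lowerable without a `!torch.number` type conversion.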