
How to lower the pattern `torch.aten.Int.Scalar(torch.aten.item(rank0_tensor))`?

Open · tanyokwok opened this issue 3 years ago · 1 comment

I want to lower a torch dialect module like the following:

```mlir
module {
  func @main(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %int2 = torch.constant.int 2
    %int0 = torch.constant.int 0
    %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.number
    %1 = torch.aten.Int.Scalar %0 : !torch.number -> !torch.int
    %2 = torch.aten.add.Scalar %arg0, %1, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
    return %2 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```
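For context, a minimal PyTorch function that would trace to roughly this pattern might look like the following (this is my reconstruction, not the actual source; the function and argument names are hypothetical):

```python
import torch

def main(x, scalar_tensor):
    # .item() on a rank-0 tensor lowers to torch.aten.item, and the
    # int(...) coercion on the resulting Python number lowers to
    # torch.aten.Int.Scalar; the addition becomes torch.aten.add.Scalar.
    return x + int(scalar_tensor.item())
```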

But I can't simply add a per-op conversion pattern somewhere like lib/Conversion/TorchToTosa/TorchToTosa.cpp, since there is no TypeConverter rule mapping between !torch.number and !torch.int.

Does anyone know how to lower this IR module?

tanyokwok · Jun 27 '22 10:06

The torch-refine-types pass should convert it to this:

```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %int2 = torch.constant.int 2
    %int0 = torch.constant.int 0
    %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
    %1 = torch.aten.Int.Scalar %0 : !torch.int -> !torch.int
    %2 = torch.aten.add.Scalar %arg0, %1, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
    return %2 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```

And then torch.aten.Int.Scalar, which is now an identity on !torch.int, should be folded away by a later pass such as the canonicalizer.
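If that fold fires, the module would end up looking roughly like this (a sketch, assuming the now-dead constants are also cleaned up; torch.aten.add.Scalar consumes the !torch.int produced by torch.aten.item directly):

```mlir
module {
  func.func @main(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[?,?,?,?],f32> {
    %int1 = torch.constant.int 1
    %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
    %1 = torch.aten.add.Scalar %arg0, %0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
    return %1 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```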

It appears that the torch-refine-types pass has a bug -- would you be able to try debugging it?

silvasean · Jun 30 '22 00:06