
Reflect/Edge/Wrap pad lowering fails with `failed to legalize operation 'torch.prim.ListConstruct'`

Open TinaAMD opened this issue 8 months ago • 0 comments

Issue

I hit an error when lowering a small ONNX model that uses Pad in reflect mode to linalg.

The error is the following:

error: failed to legalize operation 'torch.prim.ListConstruct'
    %16 = torch.prim.ListConstruct %7, %15, %5, %13, %3, %11, %1, %9 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>

Reproduce with:

torch-mlir-opt --torch-backend-to-linalg-on-tensors-backend-pipeline repro.mlir

where repro.mlir is:

module {
  func.func @test_reflect_pad(%arg0: !torch.vtensor<[1,3,4,5],si32>, %arg1: !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %int8 = torch.constant.int 8
    %str = torch.constant.str "reflect"
    %int7 = torch.constant.int 7
    %int6 = torch.constant.int 6
    %int5 = torch.constant.int 5
    %int4 = torch.constant.int 4
    %int3 = torch.constant.int 3
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %0 = torch.aten.slice.Tensor %arg1, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
    %2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
    %4 = torch.aten.slice.Tensor %arg1, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %5 = torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int
    %6 = torch.aten.slice.Tensor %arg1, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %7 = torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int
    %8 = torch.aten.slice.Tensor %arg1, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %9 = torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int
    %10 = torch.aten.slice.Tensor %arg1, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
    %12 = torch.aten.slice.Tensor %arg1, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
    %14 = torch.aten.slice.Tensor %arg1, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
    %16 = torch.prim.ListConstruct %7, %15, %5, %13, %3, %11, %1, %9 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %17 = torch.aten.pad %arg0, %16, %str, %none : !torch.vtensor<[1,3,4,5],si32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[1,3,6,7],si32>
    return %17 : !torch.vtensor<[1,3,6,7],si32>
  }
}
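The operand order of `torch.prim.ListConstruct` above (%7, %15, %5, %13, %3, %11, %1, %9) comes from rearranging the ONNX `pads` tensor into the layout `torch.aten.pad` expects. ONNX stores all begin pads first and then all end pads, outermost axis first. The torch list instead holds (begin, end) pairs starting from the innermost axis. A minimal Python sketch of that mapping, inferred from the IR above rather than taken from the torch-mlir source (`onnx_pads_to_torch` is a hypothetical helper name):

```python
def onnx_pads_to_torch(pads: list[int]) -> list[int]:
    """Reorder ONNX pads [b0, b1, ..., e0, e1, ...] into torch's
    [b_last, e_last, ..., b0, e0] layout (innermost axis first)."""
    if len(pads) % 2 != 0:
        raise ValueError("ONNX pads must have 2 * rank entries")
    rank = len(pads) // 2
    begins, ends = pads[:rank], pads[rank:]
    torch_pads: list[int] = []
    # Walk axes from innermost to outermost, emitting (begin, end) pairs.
    for axis in reversed(range(rank)):
        torch_pads.extend([begins[axis], ends[axis]])
    return torch_pads
```

Feeding it the indices 0..7 returns [3, 7, 2, 6, 1, 5, 0, 4], which matches the ListConstruct operands (%1 is pads[0], %3 is pads[1], ..., %15 is pads[7]).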

Tested with commit b3abd5666aa3fb1637f55c6bdce135c3dc84bbd9.

We see the same failure with the edge and wrap modes.
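For reference, the expected behaviour of the three modes along a single axis can be written as a small pure-Python sketch: reflect mirrors around the border element without repeating it, edge repeats the border value, and wrap continues from the opposite end. This assumes the ONNX modes match NumPy's `np.pad` modes of the same names; `pad_1d` is a hypothetical helper, not a torch-mlir or ONNX API.

```python
def pad_1d(values: list, begin: int, end: int, mode: str) -> list:
    """Pad a 1-D list by `begin`/`end` elements using reflect, edge or wrap."""
    n = len(values)
    if n == 0:
        raise ValueError("cannot pad an empty axis in this mode")
    out = []
    for i in range(-begin, n + end):  # i indexes the (virtual) padded axis
        if mode == "edge":
            j = min(max(i, 0), n - 1)  # clamp to the nearest border element
        elif mode == "wrap":
            j = i % n  # periodic continuation
        elif mode == "reflect":
            if n == 1:
                j = 0  # degenerate axis: only one value to reflect (assumed behaviour)
            else:
                period = 2 * (n - 1)  # mirror without repeating the border
                j = i % period
                if j >= n:
                    j = period - j
        else:
            raise ValueError(f"unsupported mode: {mode}")
        out.append(values[j])
    return out
```

For example, padding [1, 2, 3, 4] by 2 on each side gives [3, 2, 1, 2, 3, 4, 3, 2] for reflect, [1, 1, 1, 2, 3, 4, 4, 4] for edge and [3, 4, 1, 2, 3, 4, 1, 2] for wrap.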

Expected outcome

Lowering to linalg succeeds.

How the reproducer was generated

ONNX model in text form (this model, with a slightly newer opset, is part of the ONNX backend tests):

<
   ir_version: 10,
   opset_import: ["" : 21],
   producer_name: "backend-test"
>
test_reflect_pad (int32[1,3,4,5] x, int64[8] pads) => (int32[1,3,6,7] y) {
   y = Pad <mode: string = "reflect"> (x, pads)
}
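The declared output type int32[1,3,6,7] is the input shape int32[1,3,4,5] plus the begin and end pad of each axis. Because `pads` is a runtime input here, the concrete values are not in the model; the sketch below assumes pads = [0, 0, 1, 1, 0, 0, 1, 1] (one element of padding on each side of the last two axes) purely for illustration, and `padded_shape` is a hypothetical helper:

```python
def padded_shape(shape: list[int], pads: list[int]) -> list[int]:
    """Output shape of ONNX Pad given the input shape and pads [b0.., e0..]."""
    rank = len(shape)
    if len(pads) != 2 * rank:
        raise ValueError("pads must contain a begin and an end value per axis")
    # Each axis grows by its begin pad (pads[i]) and its end pad (pads[i + rank]).
    return [dim + pads[i] + pads[i + rank] for i, dim in enumerate(shape)]


# Assumed pad values; the real model receives them at run time.
print(padded_shape([1, 3, 4, 5], [0, 0, 1, 1, 0, 0, 1, 1]))
```

With these assumed pads the result is [1, 3, 6, 7], consistent with the declared output type.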

Steps to get the reproducer IR:

python -m torch_mlir.tools.import_onnx <file.onnx> &> importedonnx.mlir
torch-mlir-opt --torch-onnx-to-torch-backend-pipeline importedonnx.mlir &> repro.mlir
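The two steps above, followed by the failing lowering, can be chained in one script. This is a minimal sketch that only reuses the commands and flags already shown in this issue; `pipeline_commands`, `run` and the final output name `linalg.mlir` are hypothetical, and unlike `&>` it writes only stdout to each .mlir file and prints diagnostics to stderr.

```python
import subprocess
import sys
from pathlib import Path


def pipeline_commands(onnx_file: str) -> list[tuple[list[str], str]]:
    """(argv, output file) for each step; `linalg.mlir` is a hypothetical name."""
    return [
        ([sys.executable, "-m", "torch_mlir.tools.import_onnx", onnx_file],
         "importedonnx.mlir"),
        (["torch-mlir-opt", "--torch-onnx-to-torch-backend-pipeline",
          "importedonnx.mlir"], "repro.mlir"),
        (["torch-mlir-opt", "--torch-backend-to-linalg-on-tensors-backend-pipeline",
          "repro.mlir"], "linalg.mlir"),
    ]


def run(onnx_file: str) -> int:
    """Run the steps in order and stop at the first one that fails."""
    for argv, out_file in pipeline_commands(onnx_file):
        result = subprocess.run(argv, capture_output=True, text=True)
        # Keep only the IR in the output file; show diagnostics on stderr.
        Path(out_file).write_text(result.stdout)
        if result.returncode != 0:
            sys.stderr.write(result.stderr)
            print(f"step failed: {' '.join(argv)}", file=sys.stderr)
            return result.returncode
    return 0


if __name__ == "__main__" and len(sys.argv) > 1:
    sys.exit(run(sys.argv[1]))
```

Invoked as `python repro.py <file.onnx>`, it should stop at the third step and print the `failed to legalize operation 'torch.prim.ListConstruct'` diagnostic.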

The importedonnx.mlir looks like this:

module {
  func.func @test_reflect_pad(%arg0: !torch.vtensor<[1,3,4,5],si32>, %arg1: !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,3,4,5],si32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32>
    return %0 : !torch.vtensor<[1,3,6,7],si32>
  }
}

TinaAMD, Apr 09 '25 15:04