
[compilation][cpu]: failed to legalize operation onnx.Multinomial

Status: Open. pdhirajkumarprasad opened this issue 6 months ago (17 comments).

What happened?

For the following IR:

module {
  func.func @"torch-jit-export"( %arg6: !torch.vtensor<[?,4],f32>) -> (!torch.vtensor<[?,1],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.7"} {
    %82 = torch.operator "onnx.Multinomial"(%arg6) {torch.onnx.dtype = 7 : si64, torch.onnx.sample_size = 1 : si64} : (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64> 
    return %82: !torch.vtensor<[?,1],si64>
  }
}
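For reference, here is a minimal NumPy sketch of what onnx.Multinomial is expected to compute according to the ONNX operator spec. This is not the torch-mlir or IREE lowering, and the function name `multinomial_reference` is hypothetical. Each input row holds unnormalized log-probabilities over `class_size` outcomes, the op draws `sample_size` class indices per row, and `dtype = 7` selects int64 output. The sketch uses inverse-CDF sampling, which is one common approach and not necessarily the one the compiler uses. It keeps all floating-point arithmetic in float32, mirroring the rule in the error above that `arith.addf` needs operands and result of the same type.

```python
import numpy as np


def multinomial_reference(logits, sample_size=1, seed=None):
    """Hypothetical reference for ONNX Multinomial with dtype=7 (int64).

    logits: array of shape [batch_size, class_size] holding unnormalized
    log-probabilities. Returns int64 indices of shape [batch_size, sample_size].
    """
    rng = np.random.default_rng(seed)
    logits = np.asarray(logits, dtype=np.float32)
    # Normalize log-probabilities to probabilities (a softmax per row);
    # subtracting the row max keeps exp() from overflowing.
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    # Per-row cumulative distribution, kept in float32 like the input.
    cdf = np.cumsum(probs, axis=1, dtype=np.float32)
    batch_size, class_size = probs.shape
    # Uniform draws in [0, 1), cast so every comparison uses one float type.
    u = rng.random((batch_size, sample_size)).astype(np.float32)
    out = np.empty((batch_size, sample_size), dtype=np.int64)
    for b in range(batch_size):
        # Inverse CDF: first index whose cumulative probability exceeds u.
        idx = np.searchsorted(cdf[b], u[b], side="right")
        # Guard against rounding leaving cdf[-1] slightly below 1.0.
        out[b] = np.minimum(idx, class_size - 1)
    return out
```

For the IR in this issue (input `[?,4]`, `sample_size = 1`, `dtype = 7`), the sketch returns an int64 array of shape `[batch, 1]`, matching the declared result type `!torch.vtensor<[?,1],si64>`.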

Compiling this IR with iree-compile produces the following error:

t1.mlir:3:11: error: 'arith.addf' op requires the same type for all operands and results
    %82 = torch.operator "onnx.Multinomial"(%arg6) {torch.onnx.dtype = 7 : si64, torch.onnx.sample_size = 1 : si64} : (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64> 
          ^
// -----// IR Dump After ConvertTorchToLinalg Failed (convert-torch-to-linalg) //----- //
"func.func"() <{function_type = (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64>, sym_name = "torch-jit-export"}> ({
^bb0(%arg0: !torch.vtensor<[?,4],f32> loc("t1.mlir":2:34)):
  %0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[?,4],f32>) -> tensor<?x4xf32> loc("t1.mlir":3:11)
  %1 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int loc("t1.mlir":3:11)
  %2 = "arith.constant"() <{value = 1 : i64}> : () -> i64 loc("t1.mlir":3:11)
  %3 = "torch.constant.none"() : () -> !torch.none loc("t1.mlir":3:11)
  %4 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool loc("t1.mlir":3:11)
  %5 = "arith.constant"() <{value = 0 : i64}> : () -> i64 loc("t1.mlir":3:11)
  %6 = "arith.constant"() <{value = 1 : i64}> : () -> i64 loc("t1.mlir":3:11)
  %7 = "arith.constant"() <{value = 0 : index}> : () -> index loc("t1.mlir":3:11)
  %8 = "arith.constant"() <{value = 1 : index}> : () -> index loc("t1.mlir":3:11)
  %9 = "arith.index_cast"(%2) : (i64) -> index loc("t1.mlir":3:11)
  %10 = "tensor.dim"(%0, %7) : (tensor<?x4xf32>, index) -> index loc("t1.mlir":3:11)
  %11 = "tensor.dim"(%0, %8) : (tensor<?x4xf32>, index) -> index loc("t1.mlir":3:11)

IREE Version: IREE compiler version 20240819.990 @ aeda14995f16ed1302db616adf0c03acf80f27ee LLVM version 20.0.0git

Steps to reproduce your issue

Command to reproduce the issue:

iree-compile --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=llvm-cpu model.torch_onnx.mlir

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

pdhirajkumarprasad, Aug 19 '24 03:08