iree
[compilation][cpu]: failed to lower onnx.Multinomial ('arith.addf' operand type mismatch in convert-torch-to-linalg)
What happened?
Compiling the following IR for the llvm-cpu backend:
module {  // Minimal repro: one onnx.Multinomial op imported from ONNX (opset 21); diagnostics refer to this layout (t1.mlir:2:34 = %arg6, t1.mlir:3:11 = torch.operator).
  func.func @"torch-jit-export"( %arg6: !torch.vtensor<[?,4],f32>) -> (!torch.vtensor<[?,1],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.7"} {  // Input: dynamic batch of 4-way f32 scores; output: one si64 sample per row.
    %82 = torch.operator "onnx.Multinomial"(%arg6) {torch.onnx.dtype = 7 : si64, torch.onnx.sample_size = 1 : si64} : (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64>  // dtype = 7 is ONNX INT64 (matches the si64 result); sample_size = 1. Lowering this op fails in convert-torch-to-linalg with an 'arith.addf' type error.
    return %82: !torch.vtensor<[?,1],si64>
  }
}
fails with the following error:
t1.mlir:3:11: error: 'arith.addf' op requires the same type for all operands and results
%82 = torch.operator "onnx.Multinomial"(%arg6) {torch.onnx.dtype = 7 : si64, torch.onnx.sample_size = 1 : si64} : (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64>
^
// -----// IR Dump After ConvertTorchToLinalg Failed (convert-torch-to-linalg) //----- //
"func.func"() <{function_type = (!torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,1],si64>, sym_name = "torch-jit-export"}> ({
^bb0(%arg0: !torch.vtensor<[?,4],f32> loc("t1.mlir":2:34)):
%0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[?,4],f32>) -> tensor<?x4xf32> loc("t1.mlir":3:11)
%1 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int loc("t1.mlir":3:11)
%2 = "arith.constant"() <{value = 1 : i64}> : () -> i64 loc("t1.mlir":3:11)
%3 = "torch.constant.none"() : () -> !torch.none loc("t1.mlir":3:11)
%4 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool loc("t1.mlir":3:11)
%5 = "arith.constant"() <{value = 0 : i64}> : () -> i64 loc("t1.mlir":3:11)
%6 = "arith.constant"() <{value = 1 : i64}> : () -> i64 loc("t1.mlir":3:11)
%7 = "arith.constant"() <{value = 0 : index}> : () -> index loc("t1.mlir":3:11)
%8 = "arith.constant"() <{value = 1 : index}> : () -> index loc("t1.mlir":3:11)
%9 = "arith.index_cast"(%2) : (i64) -> index loc("t1.mlir":3:11)
%10 = "tensor.dim"(%0, %7) : (tensor<?x4xf32>, index) -> index loc("t1.mlir":3:11)
%11 = "tensor.dim"(%0, %8) : (tensor<?x4xf32>, index) -> index loc("t1.mlir":3:11)
IREE Version: IREE compiler version 20240819.990 @ aeda14995f16ed1302db616adf0c03acf80f27ee LLVM version 20.0.0git
Steps to reproduce your issue
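Command to reproduce the issue (the diagnostics above were produced with the IR saved as `t1.mlir`; pass that file in place of `model.torch_onnx.mlir`):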
iree-compile --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=llvm-cpu model.torch_onnx.mlir
What component(s) does this issue relate to?
Compiler
Version information
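See the IREE version listed above (IREE compiler version 20240819.990 @ aeda14995f16ed1302db616adf0c03acf80f27ee, LLVM 20.0.0git).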
Additional context
No response