torch-mlir

TorchToLinalg: casting float to integer should round to nearest

Open · bjacob opened this issue 9 months ago · 3 comments

This comes from debugging an IREE ONNX test suite failure: https://github.com/iree-org/iree/actions/runs/13796654893/job/38590777349#step:8:72

The failure message is:

 [FAILED] result[0]: element at index 1 (31) does not match the expected (32); expected that the view is equal to contents of a view of 3xi32
  expected:
3xi32=1 32 729
  actual:
3xi32=1 31 729

Notice: 32 != 31.

The test linked from that failure is: https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32

Its source code is: https://github.com/iree-org/iree-test-suites/blob/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32/model.mlir

The relevant op is:

    %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3],si32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> 

The relevant aspect is that the return type is integral, but the op internally expands to a math.powf. That produces a floating-point value, which then needs to be cast to an integer type:

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func @test_pow_types_int32_float32(%arg0: !torch.vtensor<[3],si32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],si32> -> tensor<3xi32>
  %int3 = torch.constant.int 3
  %none = torch.constant.none
  %false = torch.constant.bool false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c0_0 = arith.constant 0 : index
  %c3_1 = arith.constant 3 : index
  %2 = tensor.empty() : tensor<3xf64>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %0 : tensor<3xi32>, tensor<3xf32>) outs(%2 : tensor<3xf64>) {
  ^bb0(%in: i32, %in_6: f32, %out: f64):
    %7 = arith.sitofp %in : i32 to f64
    %8 = arith.extf %in_6 : f32 to f64
    %9 = math.powf %7, %8 : f64
    linalg.yield %9 : f64
  } -> tensor<3xf64>
  %cast = tensor.cast %3 : tensor<3xf64> to tensor<3xf64>
  %c1_2 = arith.constant 1 : index
  %c0_3 = arith.constant 0 : index
  %c3_4 = arith.constant 3 : index
  %4 = tensor.empty() : tensor<3xi32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast : tensor<3xf64>) outs(%4 : tensor<3xi32>) {
  ^bb0(%in: f64, %out: i32):
    %7 = arith.fptosi %in : f64 to i32
    linalg.yield %7 : i32
  } -> tensor<3xi32>
  %cast_5 = tensor.cast %5 : tensor<3xi32> to tensor<3xi32>
  %6 = torch_c.from_builtin_tensor %cast_5 : tensor<3xi32> -> !torch.vtensor<[3],si32>
  return %6 : !torch.vtensor<[3],si32>
}
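
The two `linalg.generic` bodies above can be emulated elementwise in Python to see where the float-to-int conversion happens. This is a minimal sketch with three assumptions:

- Python's `float` (an IEEE-754 double) stands in for `f64`.
- `math.trunc` matches `arith.fptosi` for in-range values.
- The function name `lowered_pow` and the sample inputs are hypothetical. The inputs were chosen so that the exact powers are the expected values 1, 32 and 729.

```python
import math

def lowered_pow(xs: list[int], ys: list[float]) -> list[int]:
    """Hypothetical elementwise emulation of the lowered IR above."""
    result = []
    for x, y in zip(xs, ys):
        # First generic (%3): arith.sitofp i32->f64, arith.extf f32->f64, math.powf in f64.
        p = math.pow(float(x), float(y))
        # Second generic (%5): arith.fptosi f64->i32, which truncates towards zero.
        result.append(math.trunc(p))
    return result

print(lowered_pow([1, 2, 3], [4.0, 5.0, 6.0]))  # [1, 32, 729]
```

On typical platforms `math.pow` returns these exact powers, so this sketch does not reproduce the failure by itself. It only shows that the final step is a truncation. The failure needs a `powf` implementation whose result lands slightly below the integer.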

The problem is that arith.fptosi explicitly rounds towards zero: https://mlir.llvm.org/docs/Dialects/ArithOps/#arithfptosi-arithfptosiop
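
Rounding towards zero simply discards the fractional part for both signs. This differs from flooring on negative inputs. Below is a small Python sketch that uses `math.trunc` as a stand-in for `arith.fptosi` on in-range values. Out-of-range inputs are not modeled, and the helper name `fptosi_like` is hypothetical.

```python
import math

def fptosi_like(x: float) -> int:
    # Hypothetical model of arith.fptosi: drop the fractional part (round towards zero).
    return math.trunc(x)

samples = [31.9999, 32.0, -31.9999, 0.5, -0.5]
print([fptosi_like(v) for v in samples])  # [31, 32, -31, 0, 0]
print([math.floor(v) for v in samples])   # [31, 32, -32, 0, -1]
```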

As a result, any floating-point error that yields e.g. 31.9999 instead of 32.0 causes this test failure, because 31.9999 is rounded towards zero to 31.
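
The margin for error is essentially zero: a value one unit in the last place (ULP) below 32.0 already truncates to 31. A sketch using `math.nextafter` (Python 3.9+) to build that neighboring `f64` value. The same effect occurs in `f32`, only with a larger gap between neighbors.

```python
import math

exact = 32.0
# The largest double strictly below 32.0, i.e. one ULP down.
just_below = math.nextafter(exact, 0.0)

print(math.trunc(just_below))  # 31: truncation towards zero loses a whole unit
print(round(just_below))       # 32: rounding to nearest absorbs the tiny error
```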

Instead, ConvertTorchToLinalg should round to the nearest integer before converting, e.g. by emitting math.round or math.roundeven ahead of the arith.fptosi.
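
A sketch of the numerics of that change, modeling the two candidate rounding ops in Python:

- `math.roundeven` rounds half to even, which is also what Python's built-in `round` does on floats.
- `math.round` rounds halfway cases away from zero.

The helper names below are hypothetical. This illustrates only the rounding behavior, not the actual TorchToLinalg lowering pattern.

```python
import math

def round_half_even_to_int(x: float) -> int:
    # Hypothetical model of math.roundeven followed by an int conversion.
    # Python's round() on a float rounds halfway cases to even.
    return round(x)

def round_half_away_to_int(x: float) -> int:
    # Hypothetical model of math.round (ties away from zero) followed by an int conversion.
    t = math.trunc(x)
    # x - t is the exact fractional part and carries the sign of x.
    if abs(x - t) >= 0.5:
        t += 1 if x > 0 else -1
    return t

for v in [31.9999, 2.5, 3.5, -2.5]:
    print(v, round_half_even_to_int(v), round_half_away_to_int(v))
# 31.9999 32 32
# 2.5 2 3
# 3.5 4 4
# -2.5 -2 -3
```

Both variants recover 32 from a slightly-low `powf` result. They only disagree on exact halfway cases, so choosing between them is a separate design decision for the fix.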

bjacob · Mar 12 '25 15:03

Thanks, I'll try to put up a fix this week. It should be quite straightforward.

zjgarvey · Apr 03 '25 20:04

Hi, what's the status on this? If it hasn't been done yet, I would like it to be assigned to me.

catswe · Jun 09 '25 16:06

Ah, I dropped the ball on this. Yeah, @cats-marin, feel free to pick this up.

zjgarvey · Jun 09 '25 16:06