stablehlo.sort not working correctly on Metal
What happened?
I'm calling this from Elixir via Nx, but the underlying MLIR module is the following:
module {
  func.func public @main(%arg0: tensor<6xi32>) -> tensor<6xi32> {
    %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = false}> ({
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):
      %1 = stablehlo.compare LT, %arg1, %arg2, NOTYPE : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) : (tensor<6xi32>) -> tensor<6xi32>
    return %0 : tensor<6xi32>
  }
}
Compiler flags for local-sync are:
flags = [
  "--iree-hal-target-backends=llvm-cpu",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]
Compiler flags for Metal are:
flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]
On Metal, the result is not even a permutation of the input (4 and 15 are missing, and 23 appears three times):
Nx.sort(Nx.tensor([16, 23, 42, 4, 8, 15]))
#Nx.Tensor<
s32[6]
NxIREE.Backend(metal://000000010000092e)
[8, 16, 23, 23, 23, 42]
>
On local-sync, the result is correct:
Nx.sort(Nx.tensor([16, 23, 42, 4, 8, 15]))
#Nx.Tensor<
s32[6]
NxIREE.Backend(local-sync://)
[4, 8, 15, 16, 23, 42]
>
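To make the Metal failure concrete, a minimal Python check can separate "is the output ordered?" from "does the output contain the same elements?". The helper name is_valid_sort is hypothetical (not part of Nx or IREE); the three lists are copied from the outputs above.

```python
from collections import Counter
from typing import Sequence


def is_valid_sort(inp: Sequence[int], out: Sequence[int]) -> bool:
    """True if `out` is non-decreasing and holds exactly the elements of `inp`."""
    non_decreasing = all(a <= b for a, b in zip(out, out[1:]))
    return non_decreasing and Counter(inp) == Counter(out)


inp = [16, 23, 42, 4, 8, 15]
local_sync = [4, 8, 15, 16, 23, 42]
metal = [8, 16, 23, 23, 23, 42]

print(is_valid_sort(inp, local_sync))  # True
print(is_valid_sort(inp, metal))       # False
print(Counter(inp) - Counter(metal))   # elements lost on Metal: 4 and 15
print(Counter(metal) - Counter(inp))   # extra copies on Metal: two more 23s
```

The Metal output passes the ordering check but fails the multiset check, so elements are being overwritten rather than merely misordered.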
Similarly, for NaN values: on local-sync, NaN sorts as larger than everything else, including Inf (here t is an f32[4] tensor holding -Inf, 0.0, Inf and NaN):
iex(35)> Nx.sort(t)
#Nx.Tensor<
f32[4]
NxIREE.Backend(local-sync://)
[-Inf, 0.0, Inf, NaN]
>
iex(36)> Nx.sort(t, direction: :desc)
#Nx.Tensor<
f32[4]
NxIREE.Backend(local-sync://)
[NaN, Inf, 0.0, -Inf]
>
On Metal, however, NaN sorts to the front regardless of direction:
iex(38)> Nx.sort(t)
#Nx.Tensor<
f32[4]
NxIREE.Backend(metal://000000010000092e)
[NaN, -Inf, 0.0, Inf]
>
iex(39)> Nx.sort(t, direction: :desc)
#Nx.Tensor<
f32[4]
NxIREE.Backend(metal://000000010000092e)
[NaN, Inf, 0.0, -Inf]
>
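The local-sync results are consistent with a comparator in which NaN is greater than every other value, including +Inf (IEEE 754 totalOrder places positive NaN there as well), with descending order being the exact reverse of ascending. The Python sketch below is a simplified, backend-independent reference for that ordering; it is an assumption about the intended semantics, not what Nx or IREE actually emits, and the input order of t is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def nan_last_key(x: float) -> Tuple[int, float]:
    """Sort key placing every NaN after all non-NaN values (including +Inf).

    Simplification: unlike IEEE 754 totalOrder, this ignores the NaN sign
    and payload and does not order -0.0 before 0.0.
    """
    return (1, 0.0) if math.isnan(x) else (0, x)


def reference_sort(values: Sequence[float], descending: bool = False) -> List[float]:
    """Ascending puts NaN last; descending is the reverse, so NaN comes first."""
    return sorted(values, key=nan_last_key, reverse=descending)


t = [0.0, math.nan, -math.inf, math.inf]  # hypothetical input order
print(reference_sort(t))                   # [-inf, 0.0, inf, nan]
print(reference_sort(t, descending=True))  # [nan, inf, 0.0, -inf]
```

Under this reference, the Metal ascending result [NaN, -Inf, 0.0, Inf] is wrong, while its descending result happens to match.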
Steps to reproduce your issue
No response
What component(s) does this issue relate to?
No response
Version information
IREE tag used: candidate-20240822.993
Additional context
No response
I'm not sure why this would be failing, but I would put stablehlo.sort in the category of ops that are only lightly tested and rarely used in the models most of the core team focuses on.
Brainstorming ways to improve test coverage:
- Get continuous tests running on Metal again (this needs a macOS runner; GitHub's free macOS runners don't support Metal or GPU access)
- Check whether the test at https://github.com/iree-org/iree/blob/main/tests/e2e/stablehlo_ops/sort.mlir is still representative; it was passing on Metal at some point: https://github.com/iree-org/iree/blob/afe18d222e840644bf78f103fe584610c7c9d04c/tests/e2e/stablehlo_ops/CMakeLists.txt#L620
- Import these upstream tests into iree-test-suites (see https://github.com/iree-org/iree-test-suites/issues/4), then run them on all backends: https://github.com/openxla/stablehlo/blob/e44958c8ac3178df0d69d46804960df8e92580b1/stablehlo/tests/ops_stablehlo.mlir#L2876-L2966 and https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/interpret/sort.mlir
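A minimal sketch of the differential testing the last item suggests: feed the same inputs to every backend and compare each result against a host-side reference. The case list, find_failures, SortFn, and buggy_sort are all hypothetical; each callable stands in for "run the compiled sort module on device X" and is not an IREE or iree-test-suites API.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical 1-D i32 sort cases; only the first one comes from this issue.
CASES: List[List[int]] = [
    [16, 23, 42, 4, 8, 15],      # input from this issue
    [],                          # empty
    [7],                         # single element
    [1, 2, 3, 4, 5, 6],          # already sorted
    [6, 5, 4, 3, 2, 1],          # reverse sorted
    [3, 1, 3, 1, 3, 1],          # many duplicates
    [-2**31, 2**31 - 1, 0, -1],  # i32 extremes
]

SortFn = Callable[[Sequence[int]], List[int]]


def find_failures(backends: Dict[str, SortFn]) -> Dict[str, List[List[int]]]:
    """Run every case on every backend and collect the inputs whose output
    differs from Python's sorted(), which serves as the reference."""
    failures: Dict[str, List[List[int]]] = {}
    for name, run_sort in backends.items():
        bad = [case for case in CASES if list(run_sort(case)) != sorted(case)]
        if bad:
            failures[name] = bad
    return failures


def buggy_sort(xs: Sequence[int]) -> List[int]:
    """Stand-in for a broken backend: drops the minimum, duplicates the maximum."""
    s = sorted(xs)
    return s[1:] + s[-1:]


if __name__ == "__main__":
    result = find_failures({"reference": sorted, "buggy": buggy_sort})
    for name, cases in result.items():
        print(f"{name}: {len(cases)} failing case(s)")
```

With the stand-in backend, the empty and single-element cases pass and the other five fail, which is why small edge cases alone would not have caught a bug like the one reported here.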