argsort incorrectly handles very small floating-point numbers and -0.0 compared to PyTorch
Description
When using JAX's argsort function on an array containing small floating-point numbers, as well as 0.0 and -0.0, the sorting order is incorrect compared to other libraries like PyTorch.
Specifically, JAX incorrectly places the very small positive number 1.401298464324817e-45 before both 0.0 and -0.0. Expected behavior is that both 0.0 and -0.0 should be treated as equivalent and placed before any positive numbers, including very small values like 1.401298464324817e-45. PyTorch demonstrates the correct behavior in this case.
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp

def test_argsort():
    # Input data, hardcoded as float32
    input_data = np.array([
        -0.0, 1.401298464324817e-45, 1.100000023841858, -0.0,
        5.960464477539063e-08, -2.0000100135803223, 1000000.0,
        722801.375, 0.0, -1.100000023841858
    ], dtype=np.float32)

    # PyTorch argsort
    pytorch_result = torch.argsort(torch.tensor(input_data, dtype=torch.float32)).numpy()
    print(f"PyTorch argsort result: {pytorch_result}")

    # TensorFlow argsort
    tensorflow_result = tf.argsort(input_data).numpy().astype(np.int32)
    print(f"TensorFlow argsort result: {tensorflow_result}")

    # JAX argsort
    jax_result = jnp.argsort(input_data).astype(np.int32)
    print(f"JAX argsort result: {jax_result}")

if __name__ == "__main__":
    test_argsort()
PyTorch argsort result: [5 9 0 3 8 1 4 2 7 6]
TensorFlow argsort result: [5 9 0 1 3 8 4 2 7 6]
JAX argsort result: [5 9 0 1 3 8 4 2 7 6]
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Lily的电脑', release='10', version='10.0.22631', machine='AMD64')
JAX arrays on CPU use FTZ mode, that is, subnormal numbers (like 1e-45) are flushed to zeros:
>>> with jax.default_device(jax.devices('cpu')[0]):
... jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
...
Array([ True, True, True], dtype=bool)
>>> with jax.default_device(jax.devices('cuda')[0]):
... jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
...
Array([ True, False, True], dtype=bool)
So, the issue is not in argsort but in using FTZ mode in general on CPU.
Thank you for your response, but I'd like to emphasize that the main issue is with how JAX handles -0.0 in argsort.
1. Handling of -0.0 in argsort: According to the IEEE 754 standard, -0.0 and 0.0 should be treated as equal. However, in JAX's argsort, it seems that -0.0 is treated differently from 0.0, leading to an incorrect sorting order. In my test case, JAX's argsort returns an index for -0.0 that suggests it's not equal to 0.0, which is not the expected behavior. In contrast, PyTorch correctly handles this case by treating -0.0 and 0.0 as equal, resulting in the expected sorting order.
2. FTZ (Flush to Zero) mode on CPU: While the FTZ mode may explain the handling of subnormal numbers like 1.401298464324817e-45, the core issue in this particular case is the treatment of -0.0. PyTorch, also running on the same CPU, does not exhibit this issue, suggesting that the problem is not inherent to the CPU but rather how JAX is handling -0.0 in its sorting operations. The incorrect handling of -0.0 is the root cause of the inconsistent argsort results.
Would it be possible to review how JAX is dealing with -0.0 in argsort and ensure it conforms to the IEEE standard where -0.0 and 0.0 are considered equal?
On processors that support the FTZ flag, enabling FTZ is optional. PyTorch evidently does not enable FTZ mode on CPU, while JAX (read: one of its underlying components) does.
in JAX's argsort, it seems that -0.0 is treated differently from 0.0
Looking at your test output, I would conclude that -0.0, 1e-45, 0.0 are all treated as equal because in argsort output, their relative order is unchanged (as per its stable=True option).
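To make the stability argument concrete, here is a minimal pure-Python sketch that emulates FTZ and a stable argsort on the test data. The helpers flush_to_zero and stable_argsort are hypothetical names for illustration, not JAX or PyTorch APIs, and the threshold 2**-126 is the smallest positive normal float32. It only models the comparison semantics and says nothing about how either library implements sorting.

```python
# Smallest positive normal float32; nonzero values below it in magnitude are subnormal.
FLOAT32_MIN_NORMAL = 2.0 ** -126

def flush_to_zero(x):
    """Emulate FTZ for a float32-range value (hypothetical helper)."""
    return 0.0 if x != 0.0 and abs(x) < FLOAT32_MIN_NORMAL else x

def stable_argsort(values, ftz=False):
    """Stable argsort in pure Python (hypothetical helper)."""
    keys = [flush_to_zero(v) for v in values] if ftz else list(values)
    # sorted() is stable: keys that compare equal keep their input order,
    # and -0.0 == 0.0 under IEEE 754 comparison.
    return sorted(range(len(keys)), key=keys.__getitem__)

data = [-0.0, 1.401298464324817e-45, 1.100000023841858, -0.0,
        5.960464477539063e-08, -2.0000100135803223, 1000000.0,
        722801.375, 0.0, -1.100000023841858]

print(stable_argsort(data))            # [5, 9, 0, 3, 8, 1, 4, 2, 7, 6]
print(stable_argsort(data, ftz=True))  # [5, 9, 0, 1, 3, 8, 4, 2, 7, 6]
```

The first line matches the PyTorch output and the second matches the JAX and TensorFlow output, which is what one would expect if the only difference is whether 1e-45 is flushed to zero before comparison.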
- CPU and GPU output should be consistent: According to the IEEE 754 standard, subnormal numbers (such as 1e-45) should not produce different results on the CPU and GPU. JAX uses the FTZ (Flush to Zero) mode on the CPU to flush these very small values to zero, but this mode is not enabled on the GPU, resulting in differences in sorting and comparison results.
However, users expect consistent results on all hardware platforms, especially in basic operations such as argsort. PyTorch is able to maintain consistent behavior on the CPU and GPU, indicating that this is a problem in the JAX implementation rather than a limitation of the hardware itself. Consistency on different hardware platforms is an important principle that numerical computing frameworks should follow.
- PyTorch handles it correctly: In PyTorch, the results of subnormal number processing are consistent regardless of CPU or GPU. In particular, in this case, PyTorch returns the correct sort result because it does not change the processing of these small numbers due to different platforms.
In contrast, JAX's inconsistent behavior indicates that its handling of FTZ mode is not in compliance with the standard, especially when performing basic operations like argsort. This behavior leads to inconsistent results between CPU and GPU, which introduces potential bugs. JAX should behave consistently across all platforms, just like PyTorch does.
We confirmed in a related issue (#24281) that JAX handles signed zeros consistently in operations like argsort and argmax.
With that out of the way, it seems your main concern is that JAX treats subnormal numbers differently depending on the backend. Is that correct?
Yes, ~because I am currently testing deep learning libraries, I may submit some bugs to you in a short period of time. For our evaluation criteria, the output results of different deep learning libraries under the same conditions should be consistent if there is no problem. Secondly, when it comes to calculation accuracy issues, we believe that special cases need to be clearly noted and marked. For these cases, for example, pytorch can handle this problem well, so we believe that jax's performance should also be consistent and correct if there are no other conditions. Secondly, according to the CPU and GPU examples you provided, it is obvious that the output for the same set of data inputs shows differences, which I think meets the requirements of bugs.~
I edited your response to what would have been the most helpful.
FWIW, handling subnormals in a device-dependent way complicates testing on samples with subnormals. Here's another example of the issue reported here: https://github.com/pearu/functional_algorithms/issues/38#issuecomment-2366843504 where the results of evaluating math functions near branch cuts depend on whether subnormals are flushed or not.
Enabling FTZ is an optimization method and, imho, there should exist a method (jax.config/environment variable/...) that allows controlling the state of the FTZ flag by user programs or testing scripts.
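Until such a switch exists, a test script could compare argsort outputs in a way that tolerates flushing. The following is only a sketch under that assumption: same_sorted_values is a hypothetical helper, not part of any library, and it checks that two index orders gather the same sequence of values once subnormals are optionally treated as zero.

```python
FLOAT32_MIN_NORMAL = 2.0 ** -126  # smallest positive normal float32

def _flush(x):
    # Treat float32 subnormals as zero, mimicking FTZ (illustrative only).
    return 0.0 if x != 0.0 and abs(x) < FLOAT32_MIN_NORMAL else x

def same_sorted_values(values, order_a, order_b, ftz=True):
    """Return True if both orders are permutations that gather equal values."""
    n = len(values)
    if sorted(order_a) != list(range(n)) or sorted(order_b) != list(range(n)):
        return False
    keys = [_flush(v) for v in values] if ftz else list(values)
    return all(keys[i] == keys[j] for i, j in zip(order_a, order_b))

data = [-0.0, 1.401298464324817e-45, 1.100000023841858, -0.0,
        5.960464477539063e-08, -2.0000100135803223, 1000000.0,
        722801.375, 0.0, -1.100000023841858]
torch_order = [5, 9, 0, 3, 8, 1, 4, 2, 7, 6]  # from the report above
jax_order = [5, 9, 0, 1, 3, 8, 4, 2, 7, 6]    # from the report above

print(same_sorted_values(data, torch_order, jax_order, ftz=True))   # True
print(same_sorted_values(data, torch_order, jax_order, ftz=False))  # False
```

A check like this separates genuine ordering bugs from differences that come only from the FTZ setting of the backend.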
cc/ @hawkinsp do you know whether it would be feasible to allow user-configurable FTZ semantics?