jax
Type promotion for uint32 is broken when x64 is disabled
With x64 mode disabled, type promotion for 32-bit unsigned ints can lead to data corruption:
from jax.numpy.lax_numpy import _promote_dtypes
import numpy as np
a = np.uint32(2 ** 32 - 1)
b = np.int8(1)
print(a, b)
# 4294967295 1
print(*_promote_dtypes(a, b))
# -1 1
The issue is that promotion here yields int64, which canonicalize_dtype narrows to int32 when x64 mode is disabled, and int32 (maximum 2**31 - 1) cannot represent the full range of uint32 values.
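A minimal NumPy-only sketch of the two steps described above. The second `astype` call only imitates the int64 to int32 narrowing that x64-disabled canonicalization performs; it is an assumption for illustration, not JAX's actual code path.

```python
import numpy as np

a = np.uint32(2**32 - 1)
b = np.int8(1)

# Step 1: promotion on its own is safe, since int64 can hold every uint32 value.
promoted = np.promote_types(a.dtype, b.dtype)

# Step 2: with x64 disabled, JAX narrows int64 to int32. This cast imitates
# that narrowing; out-of-range values wrap around silently.
narrowed = np.asarray(a).astype(promoted).astype(np.int32)

print(promoted, int(narrowed))  # int64 -1
print(np.iinfo(np.int32).max, np.iinfo(np.uint32).max)  # 2147483647 4294967295
```

The wrap to -1 matches the `_promote_dtypes` output in the report.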
The situation is analogous to numpy's promotion of uint64:
>>> np.promote_types(np.uint64, np.int8)
dtype('float64')
Because the largest available signed integer (int64) cannot represent the full range of uint64 values, it promotes to float64.
In JAX with x64 mode disabled, we should probably likewise promote uint32 combined with a signed integer to float32 rather than to int32.
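The proposed rule can be sketched as a NumPy-style "narrowest signed integer that covers both inputs, else float" search with a cap on integer width. The helper `promote_mixed_sign` and its `max_bits` parameter are hypothetical names for illustration; `max_bits=32` stands in for x64 mode being disabled. This is not JAX's promotion lattice, only a sketch of the behaviour suggested above.

```python
import numpy as np


def promote_mixed_sign(u, s, max_bits=64):
    """Hypothetical promotion for an (unsigned, signed) integer pair.

    Returns the narrowest signed integer no wider than max_bits that can hold
    every value of both inputs; if none exists, falls back to the float of
    width max_bits (as NumPy does for uint64 with a signed int).
    """
    u, s = np.dtype(u), np.dtype(s)
    u_max = np.iinfo(u).max
    s_min = np.iinfo(s).min
    for bits in (8, 16, 32, 64):
        if bits > max_bits:
            break
        candidate = np.dtype(f"int{bits}")
        info = np.iinfo(candidate)
        if info.max >= u_max and info.min <= s_min:
            return candidate
    return np.dtype(f"float{max_bits}")


print(promote_mixed_sign(np.uint64, np.int8))               # float64, as NumPy
print(promote_mixed_sign(np.uint32, np.int8))               # int64, as NumPy
print(promote_mixed_sign(np.uint32, np.int8, max_bits=32))  # float32 (proposal)

# What each option does to the value from the report:
value = np.asarray(np.uint32(2**32 - 1))
print(int(value.astype(np.int32)))    # -1: the sign flips
print(int(value.astype(np.float32)))  # 4294967296: magnitude kept, low bits rounded
```

The float32 fallback is not exact either (float32 has a 24-bit significand, so 2**32 - 1 rounds to 2**32), but it preserves sign and magnitude where the int32 result does not.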
Hi @jakevdp
Since lax_numpy was made private (removed from the public jax.numpy namespace) in PR #10029, I checked this issue with jax.numpy.promote_types instead. It behaves identically to numpy.promote_types for unsigned integer types, even when x64 mode is disabled.
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", False)
a = np.uint32(2 ** 32 - 1)
b = np.int8(1)
print(a, b)
# 4294967295 1
print(jnp.promote_types(a, b))
# int64
print(np.promote_types(a, b))
# int64
print(np.promote_types(np.uint32, np.int8), np.promote_types(np.uint64, np.int32))
# int64 float64
print(jnp.promote_types(jnp.uint32, jnp.int8), jnp.promote_types(jnp.uint64, jnp.int32))
# int64 float64
Output:
4294967295 1
int64
int64
int64 float64
int64 float64
Kindly find the gist for reference.
Thank you
Hi @rajasekharporeddy, thanks for the follow-up, but jnp.promote_types is not the right point of comparison here: it reports the promoted dtype before x64 canonicalization, which is why it returns float64 above even with x64 disabled. An updated version of my code from 2020 (which used private utilities then, and still uses private utilities now) would be this, and it still displays the same output:
from jax._src.numpy.util import promote_dtypes
import numpy as np
a = np.uint32(2 ** 32 - 1)
b = np.int8(1)
print(a, b)
# 4294967295 1
print(*promote_dtypes(a, b))
# -1 1
I've looked into changing this a couple of times, but it leads to a surprising number of failures both within JAX and in downstream projects.
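The gap between `jnp.promote_types` and `promote_dtypes` can be checked with public APIs: pass the `jnp.promote_types` result through x64 canonicalization, the step that turns int64 into int32 when x64 is disabled. This is a sketch assuming `jax.dtypes.canonicalize_dtype` is available in your JAX version; it inspects dtypes only and does not touch the private `promote_dtypes` helper.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Must be set before any JAX arrays are created; False is also the default.
jax.config.update("jax_enable_x64", False)

for pair in [(np.uint32, np.int8), (np.uint64, np.int32)]:
    # Result of the promotion lattice alone, as reported by jnp.promote_types.
    lattice = jnp.promote_types(*pair)
    # The dtype JAX actually materializes once x64 canonicalization applies.
    canonical = jax.dtypes.canonicalize_dtype(lattice)
    print(pair[0].__name__, pair[1].__name__, lattice, "->", canonical)
```

With x64 disabled, the uint32/int8 pair ends up as int32, which is the narrowing that corrupts 2**32 - 1 in the `promote_dtypes` example.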