jax
jnp.power gives inaccurate results
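import jax
jax.__version__  # '0.2.21'
import numpy as np
from jax import numpy as jnp
# Passes: NumPy gives the same powers of two for uint32 and float32 exponents.
assert np.all(2 ** np.arange(16, 20, dtype=np.uint32) == 2 ** np.arange(16, 20, dtype=np.float32))
# Fails on GPU (passes on CPU): the float32 path does not return exact powers of two.
assert jnp.all(2 ** jnp.arange(16, 20, dtype=jnp.uint32) == 2 ** jnp.arange(16, 20, dtype=jnp.float32))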
Thanks for the report. XLA's floating-point power computation tends to be inaccurate, particularly on GPU (note that your test case passes on CPU). This is a known issue. Results are more accurate for integer powers because JAX special-cases that path specifically to attain more accuracy: https://github.com/google/jax/blob/07083e7ec15216cc75b3e263c314e31f4169b27f/jax/_src/numpy/lax_numpy.py#L891-L909
I'm not sure whether anyone has explored making this more accurate on the XLA:GPU side; perhaps @hawkinsp would know?
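An integer power can be computed from multiplications alone, which is why the integer special case can be more accurate than a general floating-point pow. Below is a minimal pure-Python sketch of exponentiation by squaring, the standard technique for this. It is not a copy of the linked JAX code, and the helper names `to_f32` and `int_pow_f32` are hypothetical. Each intermediate product is rounded to float32 via the `struct` module to mimic float32 arithmetic; with a base of 2 every product is an exact power of two, so the result matches `2 ** n` exactly.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value.

    Values too large for float32 become +/-inf (struct raises OverflowError).
    """
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def int_pow_f32(base: float, n: int) -> float:
    """Hypothetical helper: base**n for integer n >= 0 by repeated squaring,
    rounding every intermediate product to float32."""
    if n < 0:
        raise ValueError("this sketch only handles non-negative exponents")
    result = to_f32(1.0)
    acc = to_f32(base)  # holds base**(2**k) after k squarings
    while True:
        if n & 1:  # this bit of the exponent is set: multiply it in
            result = to_f32(result * acc)
        n >>= 1
        if n == 0:
            break
        acc = to_f32(acc * acc)  # square only when more bits remain
    return result


for n in range(16, 20):
    print(n, int_pow_f32(2.0, n))
```

Because the loop uses about log2(n) multiplications, only a handful of roundings occur, and for power-of-two bases none of them loses any bits.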
2 ** np.float32(16) # = 65536.0, or 0 | 10001111 | 00000000000000000000000
2 ** jnp.float32(16) # = 65535.996, or 0 | 10001110 | 11111111111111111111111
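Bit layouts like the two above can be decoded with the standard `struct` module. This is a small illustrative sketch; the helper names `f32_bits`, `f32_fields` and `ulp_distance` are hypothetical. `ulp_distance` counts the float32 steps between two non-negative finite numbers by subtracting their bit patterns, which works because for non-negative IEEE 754 floats the bit patterns are ordered the same way as the values.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the 32-bit pattern of x after rounding it to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def f32_fields(x: float) -> str:
    """Format x as 'sign | exponent | mantissa' (1, 8 and 23 bits)."""
    b = f32_bits(x)
    sign = b >> 31
    exponent = (b >> 23) & 0xFF
    mantissa = b & 0x7FFFFF
    return f"{sign} | {exponent:08b} | {mantissa:023b}"


def ulp_distance(a: float, b: float) -> int:
    """Number of float32 steps between a and b (both finite and non-negative)."""
    return abs(f32_bits(a) - f32_bits(b))


print(f32_fields(65536.0))                # 0 | 10001111 | 00000000000000000000000
print(f32_fields(65535.996))              # 0 | 10001110 | 11111111111111111111111
print(ulp_distance(65536.0, 65535.996))   # 1
```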
Given that the difference is only 1 ULP (one unit in the last place), is there a reason (other than looking great 😎) that higher precision is desired?
I agree. It's not reasonable to expect exact equality between floating-point implementations. If particular values are producing large errors, that would be worth looking into, but ultimately we're calling a pow routine from each hardware vendor and we inherit its precision.