
jnp.power gives inaccurate results

Open · jsebestyen opened this issue 4 years ago

import jax
import numpy as np
from jax import numpy as jnp

print(jax.__version__)  # 0.2.21

# NumPy: integer and float32 powers of two agree exactly.
assert np.all(2 ** np.arange(16, 20, dtype=np.uint32) == 2 ** np.arange(16, 20, dtype=np.float32))

# JAX: this assertion fails on GPU; the float32 result is off by 1 ULP.
assert jnp.all(2 ** jnp.arange(16, 20, dtype=jnp.uint32) == 2 ** jnp.arange(16, 20, dtype=jnp.float32))

jsebestyen · Oct 02 '21 11:10

Thanks for the report – XLA's floating point power computation tends to be inaccurate, particularly on GPU (note that your test case passes on CPU). This is a known issue; integer powers come out more accurate because JAX specializes that case to attain more accuracy: https://github.com/google/jax/blob/07083e7ec15216cc75b3e263c314e31f4169b27f/jax/_src/numpy/lax_numpy.py#L891-L909
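
Roughly, that specialization is exponentiation by squaring. A minimal standalone sketch of the idea (not the actual JAX source; it assumes a static Python-int exponent) looks like this:

import jax.numpy as jnp

def int_power(x, n):
    # Exponentiation by squaring: multiply the accumulator by the current
    # square of x wherever the corresponding bit of n is set. Every step is
    # a plain float multiply, so powers of exact values like 2.0 stay exact.
    acc = jnp.ones_like(x)
    while n > 0:
        if n & 1:
            acc = acc * x
        x = x * x
        n >>= 1
    return acc

print(int_power(jnp.float32(2.0), 16))  # 65536.0, exact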

I'm not sure whether there's been any exploration on the XLA:GPU side to make this more accurate; perhaps @hawkinsp would know?

jakevdp · Oct 02 '21 14:10

# The float32 results, shown as sign | exponent | mantissa:
2 ** np.float32(16)   # = 65536.0,   or 0 | 10001111 | 00000000000000000000000
2 ** jnp.float32(16)  # = 65535.996, or 0 | 10001110 | 11111111111111111111111
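
One way to reproduce these bit patterns and confirm the gap is exactly 1 ULP (the helper name f32_bits here is just illustrative):

import numpy as np

def f32_bits(x):
    # Reinterpret the float32's raw bits and split them as
    # sign | exponent | mantissa.
    b = int(np.float32(x).view(np.uint32))
    s = f"{b:032b}"
    return f"{s[0]} | {s[1:9]} | {s[9:]}"

print(f32_bits(65536.0))    # 0 | 10001111 | 00000000000000000000000
print(f32_bits(65535.996))  # 0 | 10001110 | 11111111111111111111111

# The two values are adjacent float32s, i.e. exactly 1 ULP apart:
assert np.nextafter(np.float32(65535.996), np.float32(np.inf)) == np.float32(65536.0)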

Given that the difference is only 1 ULP, is there a reason (other than looking great 😎) that higher precision is desired?

yhtang · Sep 14 '22 18:09

I agree. It's not reasonable to expect precise equality of floating point implementations. If there are particular values that are producing large errors that would be something worth looking into, but ultimately we're calling a pow routine from each hardware vendor and we inherit its precision.
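
If exact powers of two matter in practice, here is a sketch of two workarounds (assuming, per the link above, that integer-dtype exponents take JAX's repeated-squaring path rather than the vendor pow routine):

import jax.numpy as jnp

n = jnp.arange(16, 20, dtype=jnp.int32)

# Exact by construction: ldexp(1, n) writes 2**n into the exponent field.
print(jnp.ldexp(jnp.float32(1.0), n))  # [ 65536. 131072. 262144. 524288.]

# Keeping the exponent integral avoids the float pow path entirely.
print(jnp.float32(2.0) ** n)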

hawkinsp · Sep 14 '22 19:09