`jax.scipy.special.zeta` seems to unconditionally use float32 precision
Description
I enable float64 in JAX, evaluate JAX's and mpmath's zeta, and compare the results. Only the first 7 significant digits of the JAX result agree with mpmath:
from jax import config
config.update("jax_enable_x64", True)
from jax.scipy import special as jspecial
import mpmath
z = jspecial.zeta(2, 1).item()
z_accu = mpmath.zeta(2, 1)
print('JAX: ', z)
print('mpmath:', z_accu)
Output:
JAX: 1.6449342726140739
mpmath: 1.64493406684823
This was not happening to me before version 0.4.16.
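A quick way to support the "float32 precision" reading in the title is to express the discrepancy in units of machine epsilon. The minimal sketch below reuses only the two numbers printed above and does not call JAX or mpmath. It treats the mpmath value as exact. That assumption is safe here because its own error is around 1e-15.

```python
# Values copied from the report (jax 0.4.16 with jax_enable_x64=True).
jax_value = 1.6449342726140739
mpmath_value = 1.64493406684823  # treated as the exact reference

FLOAT32_EPS = 2.0 ** -23  # machine epsilon of IEEE binary32
FLOAT64_EPS = 2.0 ** -52  # machine epsilon of IEEE binary64 (sys.float_info.epsilon)

# Relative error of the JAX result against the mpmath reference.
rel_err = abs(jax_value - mpmath_value) / abs(mpmath_value)

print(f"relative error      : {rel_err:.3e}")
print(f"in float32 eps units: {rel_err / FLOAT32_EPS:.2f}")
print(f"in float64 eps units: {rel_err / FLOAT64_EPS:.3e}")
```

The relative error is about one float32 epsilon (roughly 1.05 here) but hundreds of millions of float64 epsilons. That pattern is what a single-precision computation would produce. It is a heuristic, though, and does not show where inside the op the precision is lost.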
What jax/jaxlib version are you using?
jax 0.4.16, jaxlib 0.4.16
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.11.2, macOS 13.4
NVIDIA GPU info
No response
Thanks for the report. jax.scipy.special.zeta calls directly into XLA's Zeta operation, so the best place to file this issue would be https://github.com/openxla/xla
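For context, jax.scipy.special.zeta(s, q) is the Hurwitz zeta function ζ(s, q) = Σ_{k≥0} (k + q)^(-s). This means zeta(2, 1) is the Riemann value ζ(2) = π²/6. If you want a float64 reference that depends on neither XLA nor mpmath, the minimal pure-Python sketch below uses the Euler-Maclaurin summation formula. It only illustrates what the function computes and is not how XLA's Zeta op is implemented. The names `hurwitz_zeta` and `n_terms` are hypothetical, and the sketch only handles real s > 1 and q > 0.

```python
import math

# Even Bernoulli numbers B_2, B_4, ..., B_14 used by the Euler-Maclaurin tail.
_BERNOULLI_EVEN = [1 / 6, -1 / 30, 1 / 42, -1 / 30, 5 / 66, -691 / 2730, 7 / 6]


def hurwitz_zeta(s: float, q: float, n_terms: int = 10) -> float:
    """Hypothetical reference: zeta(s, q) = sum_{k>=0} (k + q)**(-s), real s > 1, q > 0.

    Sums the first n_terms terms directly, then approximates the remaining
    infinite tail with the Euler-Maclaurin formula (7 Bernoulli corrections).
    """
    if s <= 1 or q <= 0:
        raise ValueError("this sketch only handles s > 1 and q > 0")
    # Direct partial sum of the first n_terms terms.
    total = math.fsum((k + q) ** -s for k in range(n_terms))
    a = q + n_terms
    # Integral term plus half of the first omitted term.
    tail = a ** (1 - s) / (s - 1) + 0.5 * a ** -s
    # Corrections: B_{2j} / (2j)! * s(s+1)...(s+2j-2) * a**(-s-2j+1).
    rising = s             # s(s+1)...(s+2j-2) for j = 1
    power = a ** (-s - 1)  # a**(-s-2j+1) for j = 1
    for j, b2j in enumerate(_BERNOULLI_EVEN, start=1):
        tail += b2j / math.factorial(2 * j) * rising * power
        # Advance from j to j + 1.
        rising *= (s + 2 * j - 1) * (s + 2 * j)
        power /= a * a
    return total + tail


print(hurwitz_zeta(2.0, 1.0), math.pi ** 2 / 6)
```

Summing a handful of terms directly before applying the asymptotic tail keeps the Bernoulli corrections small. For s = 2 and q = 1 the truncation error is far below float64 resolution.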
Ok
It looks like the XLA bug was marked as fixed. In the next couple of days we should be able to test jax.scipy.special.zeta with the jaxlib nightly build to confirm that things look right on the JAX side.
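A confirmation like this comes down to comparing zeta(2, 1) against its closed form π²/6. The tolerance is better expressed in float64 ulps than as a count of printed digits. The minimal sketch below is hypothetical: `ulp_distance`, `check_zeta_2_1` and the 4-ulp budget are illustrative choices, not part of JAX's test suite. In practice the value would come from jspecial.zeta(2, 1).item() with x64 enabled. Here the sketch is fed the numbers reported in this thread so that it runs without JAX.

```python
import math


def ulp_distance(a: float, b: float) -> float:
    """Distance between two float64 values in units of math.ulp(b)."""
    return abs(a - b) / math.ulp(b)


def check_zeta_2_1(value: float, max_ulps: float = 4.0) -> bool:
    """Hypothetical regression check: zeta(2, 1) should equal pi**2 / 6 within a few ulps."""
    reference = math.pi ** 2 / 6  # closed form of the Hurwitz zeta at s=2, q=1
    return ulp_distance(value, reference) <= max_ulps


print(check_zeta_2_1(1.6449342726140739))  # value reported with jax 0.4.16 → False
print(check_zeta_2_1(1.6449340668482264))  # value reported with jax 0.4.26 → True
```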
Hi @Gattocrucco
The fix provided by openxla PR #10413 appears to address this issue. I tested the code from the report with JAX version 0.4.26 in Google Colab. The JAX result agrees with SciPy's to 15 decimal places, and the two differ only in the 16th (one float64 ulp). The printed JAX and mpmath values agree to only 13 decimal places, but that is a limit of the mpmath output. At its default precision (mp.dps = 15), mpmath prints 15 significant digits, i.e. 14 decimal places with the last one rounded. The JAX value rounded to 14 decimal places matches it exactly.
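from jax import config
config.update("jax_enable_x64", True)  # float64 is required for the outputs below
from jax.scipy import special as jspecial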
from scipy import special
import mpmath
z = jspecial.zeta(2, 1).item()
z_scipy = special.zeta(2, 1)
z_accu = mpmath.zeta(2, 1)
print('JAX :', z)
print('scipy :', z_scipy)
print('mpmath:', z_accu)
Output:
JAX : 1.6449340668482264
scipy : 1.6449340668482266
mpmath: 1.64493406684823
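The digit counts quoted above can be reproduced directly from the printed strings. The minimal sketch below uses a hypothetical helper, `matching_decimals`, that counts shared leading fractional digits. It also checks that rounding the JAX value to the 14 decimals mpmath printed gives back mpmath's output.

```python
def matching_decimals(a: str, b: str) -> int:
    """Hypothetical helper: count leading fractional digits shared by two decimal strings."""
    int_a, frac_a = a.split(".")
    int_b, frac_b = b.split(".")
    if int_a != int_b:
        return 0
    count = 0
    for da, db in zip(frac_a, frac_b):
        if da != db:
            break
        count += 1
    return count


jax_out = "1.6449340668482264"
scipy_out = "1.6449340668482266"
mpmath_out = "1.64493406684823"  # mpmath default mp.dps = 15 significant digits

print(matching_decimals(jax_out, scipy_out))   # 15
print(matching_decimals(jax_out, mpmath_out))  # 13
# Rounding the JAX value to 14 decimal places reproduces mpmath's printed value.
print(f"{float(jax_out):.14f}" == mpmath_out)  # True
```

Comparing printed digits is only a rough check, because rounding can make close values differ in several trailing digits (for example ...2999 vs ...3000). Comparing with a tolerance in ulps or relative error is more robust.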
Please find the gist for reference.
Thank you
Thanks for following up!