
`jax.scipy.special.zeta` seems to unconditionally use float32 precision

Open • Gattocrucco opened this issue 1 year ago • 3 comments

Description

I enable float64 in JAX, evaluate JAX's and mpmath's zeta, and compare the results. Only the first 7 significant digits of the JAX result agree with mpmath:

from jax import config
config.update("jax_enable_x64", True)

from jax.scipy import special as jspecial
import mpmath

z = jspecial.zeta(2, 1).item()
z_accu = mpmath.zeta(2, 1)
print('JAX:   ', z)
print('mpmath:', z_accu)

Output:

JAX:    1.6449342726140739
mpmath: 1.64493406684823

This was not happening to me before version 0.4.16.
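To put the 7-digit agreement in context: ζ(2, 1) equals the Riemann ζ(2) = π²/6, so the relative error of the reported value can be measured against float32 epsilon (2⁻²³) and float64 epsilon (2⁻⁵²) with only the standard library. This is a minimal sketch; the helper name `relative_error` is hypothetical, and the input is the JAX value from the output above.

```python
import math
import sys

# Hurwitz zeta(2, 1) equals Riemann zeta(2) = pi**2 / 6, so no mpmath is needed here.
REFERENCE = math.pi ** 2 / 6

FLOAT32_EPS = 2.0 ** -23               # spacing of float32 values in [1, 2)
FLOAT64_EPS = sys.float_info.epsilon   # 2**-52, spacing of float64 values in [1, 2)


def relative_error(value: float, reference: float = REFERENCE) -> float:
    """Relative error of `value` with respect to `reference` (hypothetical helper)."""
    return abs(value - reference) / abs(reference)


# Value printed by JAX 0.4.16 with jax_enable_x64 set, as reported above.
jax_0_4_16 = 1.6449342726140739

err = relative_error(jax_0_4_16)
print(f"relative error:      {err:.3e}")
print(f"in float32 epsilons: {err / FLOAT32_EPS:.2f}")
print(f"in float64 epsilons: {err / FLOAT64_EPS:.3e}")
```

An error of about one float32 epsilon, which is hundreds of millions of float64 epsilons, is what one would expect if the computation only converged to float32 accuracy even though the result is returned as a float64.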

What jax/jaxlib version are you using?

jax 0.4.16, jaxlib 0.4.16

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.11.2, macOS 13.4

NVIDIA GPU info

No response

Gattocrucco, Sep 22 '23

Thanks for the report. jax.scipy.special.zeta calls directly into XLA's Zeta operation, so the best place to file this issue would be at https://github.com/openxla/xla
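One way to see this from the JAX side is to trace the function and print its jaxpr, which in the versions discussed here should contain a `zeta` primitive that is handed to XLA rather than a series expansion written in JAX. A minimal sketch using `jax.make_jaxpr`; the exact jaxpr text varies between JAX versions, so the check only looks for the name.

```python
import jax
from jax.scipy import special as jspecial

# Trace zeta(x, q) for scalar float inputs without running the computation.
jaxpr = jax.make_jaxpr(jspecial.zeta)(2.0, 1.0)

# The printed jaxpr lists the primitives that will be lowered to XLA;
# for jax.scipy.special.zeta we expect a `zeta` entry among them.
print(jaxpr)
```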

jakevdp, Sep 22 '23

Ok

Gattocrucco, Sep 22 '23

It looks like the XLA bug has been marked as fixed. In the next couple of days we should be able to test jax.scipy.special.zeta with the jaxlib nightly build to confirm that things look right on the JAX side.
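A check of this kind could look like the following sketch, run against a nightly (or any newer) jaxlib. It assumes float64 is enabled and uses the closed form ζ(2, 1) = π²/6 as the reference; the function name `zeta_rel_error` and the `1e-12` tolerance are illustrative choices, not part of JAX's own test suite.

```python
import math

from jax import config
config.update("jax_enable_x64", True)  # set at startup, before creating any arrays

import jax.numpy as jnp
from jax.scipy import special as jspecial


def zeta_rel_error() -> float:
    """Relative error of jax.scipy.special.zeta(2, 1) against pi**2 / 6 (hypothetical helper)."""
    z = jspecial.zeta(jnp.float64(2.0), jnp.float64(1.0))
    # With x64 enabled the result should be a float64 array, not float32.
    assert z.dtype == jnp.float64, z.dtype
    expected = math.pi ** 2 / 6
    return abs(float(z) - expected) / expected


err = zeta_rel_error()
print(f"relative error: {err:.3e}")
# Float32-level accuracy (around 1e-7) fails this check; float64 accuracy passes.
assert err < 1e-12, "zeta(2, 1) is not accurate to float64 precision"
```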

jakevdp, Mar 13 '24

Hi @Gattocrucco

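The fix provided by openxla PR #10413 appears to address this issue. I tested the code above with JAX 0.4.26 in Google Colab. The JAX result matches SciPy's to 15 decimal places: both print 16 decimal places and differ only in the last one, which is a difference of one unit in the last place of a float64. The comparison with mpmath is limited by mpmath's default display precision of 15 significant digits (`mp.dps = 15`), which rounds the value to 14 decimal places. JAX therefore agrees with the displayed mpmath value to 13 decimal places, and exactly once the JAX result is rounded to 14.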

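from jax import config
config.update("jax_enable_x64", True)  # float64 is required to reproduce the output below

from jax.scipy import special as jspecial
from scipy import special
import mpmath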


z = jspecial.zeta(2, 1).item()
z_scipy = special.zeta(2, 1)
z_accu = mpmath.zeta(2, 1)
print('JAX   :', z)
print('scipy :', z_scipy)
print('mpmath:', z_accu)

Output:

JAX   : 1.6449340668482264
scipy : 1.6449340668482266
mpmath: 1.64493406684823

Please find the gist for reference.
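Counting matching decimal places depends on how each library formats its output. A more direct measure is the distance in units in the last place (ULPs) between two float64 values, which can be computed exactly from their bit patterns with the standard library. A minimal sketch; `ulp_distance` and `_ordered_bits` are hypothetical helpers, and the inputs are the values printed in this thread.

```python
import math
import struct


def _ordered_bits(x: float) -> int:
    """Map a finite float64 to an integer so that adjacent floats differ by exactly 1."""
    (u,) = struct.unpack("<Q", struct.pack("<d", x))
    sign_mask = 1 << 63
    # Non-negative floats already order like their bit patterns; negative ones are mirrored,
    # so -0.0 and +0.0 both map to 0.
    return -(u & (sign_mask - 1)) if u & sign_mask else u


def ulp_distance(a: float, b: float) -> int:
    """Number of float64 steps separating a and b (hypothetical helper)."""
    if not (math.isfinite(a) and math.isfinite(b)):
        raise ValueError("ulp_distance is only defined for finite inputs")
    return abs(_ordered_bits(a) - _ordered_bits(b))


jax_0_4_26 = 1.6449340668482264
scipy_value = 1.6449340668482266
jax_0_4_16 = 1.6449342726140739  # value from the original report

print(ulp_distance(jax_0_4_26, scipy_value))
print(ulp_distance(jax_0_4_16, jax_0_4_26))
```

The two float64 results from JAX 0.4.26 and SciPy are a single ULP apart, whereas the JAX 0.4.16 value is nearly a billion ULPs away.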

Thank you

rajasekharporeddy, Apr 25 '24

Thanks for following up!

jakevdp, Apr 25 '24