Improve precision of 32-bit `gammaln`?

Open lucascolley opened this issue 1 year ago • 4 comments

Description

In [2]: from jax.scipy.special import gammaln as gammaln_jax
In [5]: x = jax.numpy.asarray(2.00001)
In [7]: gammaln_jax(x)
Out[7]: Array(5.722046e-06, dtype=float32, weak_type=True)
...
In [1]: from scipy.special import gammaln
In [4]: x = np.asarray(2.00001)
In [5]: gammaln(x)
Out[5]: 4.227875597648359e-06

https://www.wolframalpha.com/input?i=ln%28Gamma%282.00001%29%29

Found in https://github.com/scipy/scipy/pull/20085#issuecomment-2117810225.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.27
jaxlib: 0.4.23.dev20240502
numpy:  1.26.4
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Lucass-MacBook-Air-4.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:41 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8103', machine='arm64')

lucascolley • May 17 '24 15:05

Thanks for the report! This looks like a floating-point precision issue. SciPy uses 64-bit precision, while JAX uses 32-bit precision by default. If you enable 64-bit precision in JAX, you get the expected output:

In [1]: import jax
In [2]: jax.config.update('jax_enable_x64', True)
In [3]: x = jax.numpy.asarray(2.00001)
In [4]: jax.scipy.special.gammaln(x)
Out[4]: Array(4.2278756e-06, dtype=float64, weak_type=True)
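As a rough illustration of how much of the gap comes from the input dtype alone, the standard-library sketch below (not JAX code; `to_float32` is a hypothetical helper introduced here) rounds 2.00001 to float32 and evaluates `math.lgamma` in double precision on both inputs. Float32 stores the input as 2 + 42·2⁻²², so even a correctly rounded 32-bit result would land near 4.2336e-06 rather than 4.2279e-06. The reported 5.722046e-06 equals 24·2⁻²² to printed precision, which is consistent with cancellation in the 32-bit evaluation, although that reading is an inference and not something confirmed in this thread.

```python
import math
import struct


def to_float32(value: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value.

    Hypothetical helper for illustration; it is not part of JAX or SciPy.
    """
    return struct.unpack("f", struct.pack("f", value))[0]


x64 = 2.00001           # the input as float64 sees it
x32 = to_float32(x64)   # the input as float32 sees it

# Spacing between adjacent float32 values in [2, 4) is 2**-22.
ulp = 2.0 ** -22

print("float32 input:            ", x32)
print("lgamma of float64 input:  ", math.lgamma(x64))
print("lgamma of float32 input:  ", math.lgamma(x32))
print("reported JAX result / ulp:", 5.722046e-06 / ulp)
```

The third printed value is the target a 32-bit implementation could realistically aim for; the difference between it and the second value is input rounding, not an error in `lgamma`.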

That said, it seems like the 32-bit computation should be able to return a more precise answer here – in the case of gammaln, JAX is more-or-less directly calling OpenXLA's lgamma function, which is implemented here: https://github.com/openxla/xla/blob/1b9e830bc32da1e75a4bab4130376accd7e61e4d/xla/client/lib/math.cc#L513

I wonder if there's a different series expansion we could use that would be more accurate for small outputs?
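One candidate, sketched here under stated assumptions and not taken from XLA, is the Taylor series of `lgamma` about its zero at x = 2: lgamma(2 + d) = (1 − γ)·d + (ζ(2) − 1)/2·d² − (ζ(3) − 1)/3·d³ + …, where γ is the Euler–Mascheroni constant. Every term is small when d is small, so nothing large has to cancel. The code below emulates float32 arithmetic in pure Python by rounding each intermediate; `f32` and `lgamma_near_two_f32` are hypothetical names, and truncating at three terms is only adequate for very small |d|.

```python
import math
import struct

EULER_GAMMA = 0.5772156649015329   # Euler-Mascheroni constant
ZETA_2 = math.pi ** 2 / 6          # Riemann zeta(2)
ZETA_3 = 1.2020569031595942        # Riemann zeta(3), Apery's constant


def f32(value: float) -> float:
    """Round a binary64 value to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def lgamma_near_two_f32(x: float) -> float:
    """Three-term Taylor series of lgamma about x = 2.

    Every intermediate is rounded to float32 to mimic 32-bit evaluation.
    Illustrative sketch only; not how XLA implements lgamma.
    """
    d = f32(f32(x) - 2.0)                 # offset from the zero at x = 2
    c1 = f32(1.0 - EULER_GAMMA)           # coefficient of d
    c2 = f32((ZETA_2 - 1.0) / 2.0)        # coefficient of d**2
    c3 = f32(-(ZETA_3 - 1.0) / 3.0)       # coefficient of d**3
    # Horner form: d * (c1 + d * (c2 + d * c3))
    acc = f32(c2 + f32(d * c3))
    acc = f32(c1 + f32(d * acc))
    return f32(d * acc)


x = 2.00001
approx = lgamma_near_two_f32(x)
reference = math.lgamma(f32(x))   # double-precision value at the float32 input
print(approx, reference)
```

A production version would still need a rule for when to switch to the series, enough terms for the chosen interval, and a matching expansion about the other zero at x = 1.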

jakevdp • May 17 '24 17:05

Interesting, I thought that we were enabling 64-bit JAX over in SciPy. Let me check.

lucascolley • May 17 '24 18:05

We have this in our conftest.py, so not sure why we observed the less precise value in CI:

import jax.numpy  # type: ignore[import-not-found]
xp_available_backends.update({'jax.numpy': jax.numpy})
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices(SCIPY_DEVICE)[0])

lucascolley • May 17 '24 18:05

~Feel free to close this, I re-ran CI and the test passed, so perhaps a temporary blip out of 64-bit mode for some reason.~

This was our mistake: the test was passing in float32 arrays, and the tolerance was only intermittently being violated.

lucascolley • May 17 '24 19:05