
Edge behavior in `jax.scipy.special.betainc`


Description

`jax.scipy.special.betainc` seems to have trouble with very small values of the parameter `a`, at least for certain values of `b` and `x`.

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()
[plot: betainc(a, 1, 0.25) vs. a (log-log); the jax curve departs from scipy at small a]

I know that it is difficult to guarantee accuracy to machine precision for all possible combinations of inputs :) Just thought I'd point out this problem spot, since it came up in SciPy testing (scipy/scipy#20963).

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='e901fac133dc', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sun Apr 28 14:29:16 UTC 2024', machine='x86_64')

mdhaber commented on Jun 15 '24

Hi @mdhaber

JAX typically uses single-precision floating-point numbers for calculations, while SciPy defaults to double precision. This difference in precision can lead to different results, especially when working with very small numbers. If double precision is enabled in JAX, then JAX yields results consistent with SciPy even for very small numbers:

import jax

jax.config.update('jax_enable_x64', True)

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[plot: betainc(a, 1, 0.25) vs. a (log-log); with x64 enabled, the scipy and jax curves agree]

Please find the gist for reference.

Thank you.

rajasekharporeddy commented on Jun 17 '24

Thanks! This actually came up in the context of 32-bit calculations, though. The definitions should have been:

a = np.logspace(-40, -1, 300, dtype=np.float32)
b = np.float32(1.)
x = np.float32(0.25)

and the plot looks the same. SciPy's outputs are float32, so I assume that's being preserved internally, although perhaps it is converting back and forth. In any case, I know the trouble area is toward the small end of normal numbers and extends into the subnormals, so I understand if it's not a priority. Feel free to close!
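
For what it's worth, the output dtype can be checked directly (an illustrative snippet, not from the original report):

import numpy as np
from scipy.special import betainc

a32 = np.logspace(-40, -1, 5, dtype=np.float32)
out = betainc(a32, np.float32(1.), np.float32(0.25))
print(out.dtype)  # float32, consistent with the dtype being preserved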


To zoom in:

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, betainc_jax(xp.asarray(a), b, x), label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

[plot: betainc(a, 1, 0.25) vs. a near the float32 smallest normal (log-log)]

mdhaber commented on Jun 17 '24

Hi @mdhaber

IIUC, according to scipy/scipy#8495 (comment), SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another type, whereas JAX does them in float32 itself. That might be causing this difference.
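
If that were the case, SciPy's float32 results should exactly match its float64 results rounded to float32. A quick illustrative check (an assumption-probing snippet, not something verified here):

import numpy as np
from scipy.special import betainc

a = np.logspace(-40, -1, 300, dtype=np.float32)
res32 = betainc(a, np.float32(1.), np.float32(0.25))
res64 = betainc(a.astype(np.float64), 1., 0.25).astype(np.float32)
print(np.array_equal(res32, res64))  # True would support the hypothesis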

Thank you.

rajasekharporeddy commented on Jun 24 '24

Whatever the reasons for SciPy to use float64 internally (one practical reason could be that no float32 implementation is available to SciPy, for instance), evaluating functions correctly in float32 requires an algorithm that properly handles overflows, underflows, and cancellations. Using higher precision is a typical cheap trick to avoid paying attention to these floating-point errors and to keep the algorithms simple. So, I wonder where the jax.scipy.special.betainc implementation lives, as it may explain the behavior observed in this issue.
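
As a generic illustration of the cancellation point (not JAX's actual code): for b = 1 we have betainc(a, 1, x) = x**a, and computing its complement 1 - x**a naively in float32 loses most of its significant digits, while an expm1-based form does not:

import numpy as np

a = np.float32(1e-6)
x = np.float32(0.25)
naive = np.float32(1.0) - x**a     # subtracts two nearly equal numbers
stable = -np.expm1(a * np.log(x))  # same quantity, without the cancellation
print(naive, stable)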

pearu commented on Jun 24 '24

JAX's implementation is here, and mentions that it's based on http://dlmf.nist.gov/8.17.E23: https://github.com/google/jax/blob/2b728d55b6054bba8ae26b3523722e80d660e771/jax/_src/lax/special.py#L182-L190
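
For context, here is a minimal NumPy sketch of that continued fraction (DLMF 8.17.22 with coefficients 8.17.23 and 8.17.24, evaluated bottom-up with a fixed number of terms); this is an illustration, not JAX's actual implementation:

import numpy as np
from scipy.special import betaln  # log B(a, b), to avoid overflow in the prefactor

def betainc_cf(a, b, x, terms=50):
    # Evaluate 1 / (1 + d1 / (1 + d2 / (1 + ...))) from the innermost term outward.
    cf = 1.0
    for n in range(terms, 0, -1):
        m = n // 2
        if n % 2 == 0:  # d_{2m} = m (b - m) x / ((a + 2m - 1)(a + 2m))
            d = m * (b - m) * x / ((a + 2*m - 1) * (a + 2*m))
        else:           # d_{2m+1} = -(a + m)(a + b + m) x / ((a + 2m)(a + 2m + 1))
            d = -(a + m) * (a + b + m) * x / ((a + 2*m) * (a + 2*m + 1))
        cf = 1.0 + d / cf
    # Prefactor x**a * (1 - x)**b / (a * B(a, b)), computed in log space.
    log_pre = a * np.log(x) + b * np.log1p(-x) - np.log(a) - betaln(a, b)
    return np.exp(log_pre) / cf

# The fraction converges quickly for x < (a + 1) / (a + b + 2); otherwise one
# would first apply the symmetry betainc(a, b, x) = 1 - betainc(b, a, 1 - x).
print(betainc_cf(0.5, 1.0, 0.25))  # 0.5, matching scipy.special.betainc(0.5, 1.0, 0.25)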

jakevdp commented on Jun 24 '24

> SciPy does all the internal (low-level C) calculations in float64 even if the input is float32 or another type.

But that comment is about `scipy.ndimage.affine_transform`, not `scipy.special.betainc`.

I confirmed with @steppi that SciPy now uses Boost's `ibeta` for `betainc`, and the types seem to be preserved in the calculation.

Here is where `betainc` is defined in terms of `ibeta`: https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/functions.json#L118-L123

Here is where `ibeta` is used for the float and double instantiations of the function: https://github.com/scipy/scipy/blob/e36e728081475466d2faae65e1dfecfa2314c857/scipy/special/boost_special_functions.h#L106-L116

And Boost's `ibeta` is templated: https://beta.boost.org/doc/libs/1_68_0/libs/math/doc/html/math_toolkit/sf_beta/ibeta_function.html

mdhaber commented on Jun 24 '24

Checking the same code on GPU, we get slightly different plots:

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a = np.logspace(-40, -1, 300)
b = 1
x = 0.25

output = betainc_jax(xp.asarray(a), b, x)

plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

print(output.devices(), output.dtype)
# {cuda(id=0)} float32

[plot: betainc(a, 1, 0.25) vs. a (log-log), float32 on GPU]

and

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import betainc as betainc_scipy
import jax.numpy as xp
from jax.scipy.special import betainc as betainc_jax

a0 = np.finfo(np.float32).smallest_normal
b = np.float32(1.)
x = np.float32(0.25)
factor = np.float32(10)
a = np.logspace(np.log10(a0), np.log10(a0*factor), 300, dtype=np.float32)

output = betainc_jax(xp.asarray(a), b, x)
plt.loglog(a, betainc_scipy(a, b, x), label='scipy')
plt.loglog(a, output, label='jax')
plt.xlabel('a')
plt.ylabel('betainc(a, 1, 0.25)')
plt.legend()

print(output.devices(), output.dtype)
# {cuda(id=0)} float32

[plot: betainc(a, 1, 0.25) vs. a near the float32 smallest normal, float32 on GPU]

So, to reproduce the issue we can add the following at the top of the code, before importing jax:

import os
# Hide GPUs and force the CPU backend; must be set before jax is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["JAX_PLATFORMS"] = "cpu"

vfdev-5 commented on Jul 09 '24

Thanks for the investigation @vfdev-5. I extended the top post with a description of a separate issue regarding edge case behavior, which is probably easier to address.

mdhaber commented on Feb 08 '25

Heads up: https://github.com/jax-ml/jax/pull/27107 fixes all issues reported here.

pearu commented on Mar 13 '25

Thank you!

mdhaber commented on Mar 20 '25