Gumbel distribution implementation
Implementation of the Gumbel distribution (see the linked issue).
SciPy computes in 64-bit floating point by default, so I set the tolerance in the tests to 5e-3 (I was getting errors when it was 5e-4).
Whenever I faced numerical issues with xlogy and xlog1py, I replaced them with lax.log1p and lax.log. However, sf, particularly for gumbel_r, had issues due to the calculation of log(1 - exp(-exp(-z))), so I implemented _log1mexp based on the TensorFlow implementation.
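For reference, here is a minimal sketch of the two-branch trick behind _log1mexp, assuming the convention log1mexp(x) = log(1 - exp(x)) for x <= 0 (implementations differ on sign conventions; this is my paraphrase of the standard approach from the Rmpfr note and TFP, not necessarily the exact helper in this PR):

import jax.numpy as jnp

def _log1mexp(x):
  # Computes log(1 - exp(x)) stably for x <= 0 (sign convention assumed).
  # Near zero, 1 - exp(x) cancels catastrophically, so log(-expm1(x)) is
  # used; for very negative x, exp(x) underflows harmlessly, so
  # log1p(-exp(x)) is accurate. The usual switch point is x = -log(2).
  return jnp.where(x > -jnp.log(2.0),
                   jnp.log(-jnp.expm1(x)),
                   jnp.log1p(-jnp.exp(x)))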
cc @jakevdp
import scipy.stats as stats
import jax.scipy.stats as jstats
import numpy as np

x = np.arange(10.0, dtype=np.float32)

# A negative scale is invalid, so both libraries should return NaN:
print("scipy")
print(stats.gumbel_l.logpdf(x, loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logpdf(x, loc=0, scale=-1))
print("scipy")
print(stats.gumbel_l.logcdf(x, loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logcdf(x, loc=0, scale=-1))
print("scipy")
print(stats.gumbel_l.logsf(x, loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logsf(x, loc=0, scale=-1))

# A tiny positive scale, where float32 (JAX default) and float64 (SciPy)
# behavior can differ:
print("scipy")
print(stats.gumbel_l.logpdf(x, loc=0, scale=1e-32))
print("jax")
print(jstats.gumbel_l.logpdf(x, loc=0, scale=1e-32))
Output:
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
WARNING:2025-06-12 09:39:47,787:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[nan nan nan nan nan nan nan nan nan nan]
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
[nan nan nan nan nan nan nan nan nan nan]
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
[nan nan nan nan nan nan nan nan nan nan]
scipy
miniconda3/envs/jax/lib/python3.11/site-packages/scipy/stats/_continuous_distns.py:4430: RuntimeWarning: overflow encountered in exp
return x - np.exp(x)
[72.68272298 -inf -inf -inf -inf -inf
-inf -inf -inf -inf]
jax
[72.682724 -inf -inf -inf -inf -inf -inf
-inf -inf -inf]
Sorry for any inconvenience caused. This is my first time contributing. @jakevdp
I think scale=1E-32 is the issue here: scipy will compute in float64 by default, while JAX will compute in float32 by default. With such a small scale, I suspect that float32 is underflowing while float64 is not, which leads to different results.
If you want to compare equivalent computations, you can set JAX_ENABLE_X64=true (see https://docs.jax.dev/en/latest/default_dtypes.html for info on this)
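For example (the environment variable must be set before JAX is imported; the config call is the runtime equivalent):

import os
os.environ["JAX_ENABLE_X64"] = "true"  # before any JAX import

import jax
jax.config.update("jax_enable_x64", True)  # equivalent, at runtime

import jax.numpy as jnp
print(jnp.arange(10.0).dtype)  # float64 instead of the default float32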
@jakevdp can we proceed further or is there anything that you would like to change?
I’ve updated the implementation as per your comments. Please take another look when convenient, happy to iterate if needed.
Looks like this comment got lost in the mix:
Also, we'll need to add the new functionality to the docs in this section:
https://github.com/jax-ml/jax/blob/d39f29ca3a8a3faa71f35aa427927a22869e57c2/docs/jax.scipy.rst?plain=1#L200
I have added it based on the previously implemented distributions, but I'm not sure whether it's right. If there is a way to check this, please let me know.
Hi! Just checking in, is there anything else I should do on this PR, or is it good to go?
I hope everything is fine. Let me know if there is anything else that I need to change.
The "Oldest Supported NumPy" tests and the CPU tests with x64 = 1 failed. Should I promote the dtypes before returning from args_maker(), like this?
def args_maker():
  x, loc, scale = map(rng, shapes, dtypes)
  promote_dtype = jnp.result_type(x, loc, scale)
  scale = np.abs(scale) + 0.1  # ensure scale > 0
  return [x.astype(promote_dtype), loc.astype(promote_dtype),
          scale.astype(promote_dtype)]
I'm not sure why it fails specifically for Python 3.13 in Actions; locally, both 3.11 and 3.13 pass:
echo $JAX_ENABLE_X64
1
pytest -n 2 tests/scipy_stats_test.py
================================================== test session starts ==================================================
platform linux -- Python 3.11.0, pytest-8.4.0, pluggy-1.6.0
rootdir: /home/mms/dev/jax
configfile: pyproject.toml
plugins: xdist-3.7.0, hypothesis-6.135.0
2 workers [940 items]
................................................................................................................. [ 12%]
................................................................................................................. [ 24%]
................................................................................................................. [ 36%]
................................................................................................................. [ 48%]
................................................................................................................. [ 60%]
................................................................................................................. [ 72%]
................................................................................................................. [ 84%]
.................................................ssssssssss...................................................... [ 96%]
.................................... [100%]
====================================== 930 passed, 10 skipped in 112.76s (0:01:52) ======================================
pytest -n 2 tests/scipy_stats_test.py
================================================== test session starts ==================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/mms/dev/jax
configfile: pyproject.toml
plugins: xdist-3.8.0, hypothesis-6.135.32
2 workers [940 items]
................................................................................................................. [ 12%]
................................................................................................................. [ 24%]
................................................................................................................. [ 36%]
................................................................................................................. [ 48%]
................................................................................................................. [ 60%]
................................................................................................................. [ 72%]
................................................................................................................. [ 84%]
..............................................................ssssssssss......................................... [ 96%]
.................................... [100%]
====================================== 930 passed, 10 skipped in 110.73s (0:01:50) ======================================
Coming to the -inf mismatch in the CPU tests, I changed the code to:
x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
z = lax.div(lax.sub(x, loc), scale)
neg_exp_z = lax.neg(lax.exp(z))
# xlog1py loses precision here, which is why log1p was used instead;
# but even log1p fails for some cases in float64 mode, so we use
# log(-expm1(-exp(z))), which is stable.
log_cdf = lax.log(-lax.expm1(neg_exp_z))
return jnp.where(ok, log_cdf, np.nan)
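For context, here is a small float64 repro I put together (not from the test suite) showing why the log1p form hits -inf in the deep left tail while the expm1 form stays finite:

import numpy as np

z = -100.0              # deep left tail of gumbel_l
neg_exp_z = -np.exp(z)  # about -3.7e-44
# exp(neg_exp_z) rounds to exactly 1.0 in float64, so log1p sees -1.0:
print(np.log1p(-np.exp(neg_exp_z)))  # -inf
# expm1 retains the tiny value, so the stable form recovers the tail:
print(np.log(-np.expm1(neg_exp_z)))  # -100.0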
I also tested edge cases that weren't covered by the JAX test suite:
edge_cases = [
    (jnp.array([-100.0]), jnp.array([0.0]), jnp.array([1.0])),    # left tail
    (jnp.array([-50.0]), jnp.array([0.0]), jnp.array([1.0])),
    (jnp.array([-1000.0]), jnp.array([0.0]), jnp.array([1.0])),   # deep left tail
    (jnp.array([0.0]), jnp.array([0.0]), jnp.array([1000.0])),    # very large scale
    (jnp.array([1.0]), jnp.array([0.0]), jnp.array([0.001])),     # very small scale
    (jnp.array([-1000.0]), jnp.array([1000.0]), jnp.array([0.1])),  # far from loc
    (jnp.array([-100.0, -10.0, 0.0, 10.0]), jnp.zeros(4), jnp.ones(4)),  # vectorized
]
for i, (x, loc, scale) in enumerate(edge_cases):
  test_name = f"Edge case {i}"
  try:
    jax_result = lsp_stats.gumbel_l.logcdf(x, loc, scale)
    scipy_result = osp_stats.gumbel_l.logcdf(np.array(x), np.array(loc), np.array(scale))
    self.assert_allclose(jax_result, scipy_result, test_name)
  except Exception as e:
    self.log_result(test_name, False, str(e))
It passed. What's your suggestion on this, @jakevdp?
Perhaps we should consider exposing log1mexp in the public API, probably under jax.nn; see the sketch after the related links below.
Related:
- https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
- https://www.tensorflow.org/probability/api_docs/python/tfp/math/log1mexp
- https://www.pymc.io/projects/docs/en/stable/api/generated/pymc.math.log1mexp.html
- https://www.rdocumentation.org/packages/VGAM/versions/1.1-13/topics/log1mexp
- https://juliastats.org/LogExpFunctions.jl/stable/#LogExpFunctions.log1mexp
- https://stdlib.io/docs/api/latest/@stdlib/math/iter/special/log1mexp
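To make the motivation concrete, a hypothetical usage sketch (jax.nn.log1mexp is an assumed name; nothing like it is exposed today): for gumbel_r, logsf(x) = log(1 - exp(-exp(-z))) = log1mexp(-exp(-z)), where the naive form underflows in float32 while the stable form recovers the true tail behavior logsf(x) ≈ -z:

import jax.numpy as jnp

z = jnp.float32(20.0)
naive = jnp.log(1.0 - jnp.exp(-jnp.exp(-z)))  # exp(-exp(-z)) rounds to 1.0 -> log(0) = -inf
stable = jnp.log(-jnp.expm1(-jnp.exp(-z)))    # the log1mexp form: ~ -20, the correct tail
print(naive, stable)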