
gumbel distribution implementation

(Open) SuriyaaMM opened this issue 7 months ago

Implementation of the Gumbel distribution (see the linked issue). SciPy computes in 64-bit floating point while JAX defaults to 32-bit, so I set the test tolerance to 5e-3 (with 5e-4 the tests failed). Whenever I ran into numerical issues with xlogy and xlog1py, I replaced them with lax.log and lax.log1p. However, the log survival function, particularly in gumbel_r, had issues because it computes log(1 - exp(-exp(-z))), so I implemented a _log1mexp helper based on the TensorFlow implementation.
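For reference, these are the standard gumbel_r formulas in terms of the standardized variable z, with μ = loc and β = scale. The PR's code may arrange them differently:

```latex
\begin{aligned}
z &= \frac{x - \mu}{\beta}, \\
\log f_r(x) &= -z - e^{-z} - \log\beta, \\
F_r(x) &= \exp\!\left(-e^{-z}\right), \\
\log S_r(x) &= \log\!\left(1 - \exp\!\left(-e^{-z}\right)\right)
             = \log\!\left(-\operatorname{expm1}\!\left(-e^{-z}\right)\right).
\end{aligned}
```

For large z, e^{-z} is tiny, exp(-e^{-z}) rounds to exactly 1, and the naive form returns log(0) = -inf even though log S_r ≈ -z there. Evaluating 1 - exp(-a) through expm1, as a log1mexp-style helper does, keeps that small difference.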

cc @jakevdp

SuriyaaMM, Jun 09 '25

import scipy.stats as stats
import jax.scipy.stats as jstats
import numpy as np

# scale=-1 is invalid, so both libraries should return nan
print("scipy")
print(stats.gumbel_l.logpdf(np.arange(10.0).astype(np.float32), loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logpdf(np.arange(10.0), loc=0, scale=-1))

print("scipy")
print(stats.gumbel_l.logcdf(np.arange(10.0), loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logcdf(np.arange(10.0), loc=0, scale=-1))

print("scipy")
print(stats.gumbel_l.logsf(np.arange(10.0), loc=0, scale=-1))
print("jax")
print(jstats.gumbel_l.logsf(np.arange(10.0), loc=0, scale=-1))

# Valid but tiny scale: SciPy computes in float64, JAX in float32 by default
print("scipy")
print(stats.gumbel_l.logpdf(np.arange(10.0).astype(np.float32), loc=0, scale=1e-32))
print("jax")
print(jstats.gumbel_l.logpdf(np.arange(10.0).astype(np.float32), loc=0, scale=1e-32))

Output:
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
[nan nan nan nan nan nan nan nan nan nan]
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
[nan nan nan nan nan nan nan nan nan nan]
scipy
[nan nan nan nan nan nan nan nan nan nan]
jax
[nan nan nan nan nan nan nan nan nan nan]
scipy
miniconda3/envs/jax/lib/python3.11/site-packages/scipy/stats/_continuous_distns.py:4430: RuntimeWarning: overflow encountered in exp
  return x - np.exp(x)
[72.68272298        -inf        -inf        -inf        -inf        -inf
        -inf        -inf        -inf        -inf]
jax
[72.682724      -inf      -inf      -inf      -inf      -inf      -inf
      -inf      -inf      -inf]
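As a sanity check on the last pair of outputs: SciPy's gumbel_l log-density in standardized form is z - exp(z) (the `return x - np.exp(x)` line in the RuntimeWarning above), shifted by -log(scale). With μ = loc = 0 and β = scale = 10^{-32}:

```latex
\begin{aligned}
\log f_l(x) &= z - e^{z} - \log\beta, \qquad z = \frac{x - \mu}{\beta}, \\
x = 0:\quad \log f_l &= 0 - 1 + 32\ln 10 \approx -1 + 73.68272 = 72.68272, \\
x = 1:\quad z &= 10^{32},\; e^{z} \to \infty \;\Rightarrow\; \log f_l = -\infty.
\end{aligned}
```

So the finite value at x = 0 and the -inf entries (from exp overflowing) agree between the two libraries up to float32 rounding.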

Sorry for any inconvenience caused. This is my first time contributing. @jakevdp

SuriyaaMM, Jun 12 '25

I think scale=1E-32 is the issue here: scipy will compute in float64 by default, while JAX will compute in float32 by default. With such a small scale, I suspect that float32 is underflowing while float64 is not, which leads to different results.

If you want to compare equivalent computations, you can set JAX_ENABLE_X64=true (see https://docs.jax.dev/en/latest/default_dtypes.html for info on this)
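A minimal sketch of the x64 switch: `jax.config.update("jax_enable_x64", True)` is the in-process counterpart of the JAX_ENABLE_X64 environment variable. The 1 - exp(-1e-10) comparison is only an illustrative example of a small quantity that float32 cannot resolve but float64 can; it is not taken from the Gumbel code.

```python
import jax

# In-process equivalent of JAX_ENABLE_X64=true; set it before creating arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.arange(10.0)
print(x.dtype)  # float64 with x64 enabled; float32 under JAX's default

# The same expression evaluated in each precision: 1 - exp(-1e-10)
diff32 = 1.0 - jnp.exp(jnp.float32(-1e-10))  # the 1e-10 difference is lost in float32
diff64 = 1.0 - jnp.exp(jnp.float64(-1e-10))  # float64 keeps roughly 1e-10
print(diff32, diff64)
```

Setting the environment variable before launching Python has the same effect and avoids having to run the config update before any other JAX code.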

jakevdp, Jun 13 '25

@jakevdp, can we proceed, or is there anything you would like to change?

SuriyaaMM, Jul 02 '25

I’ve updated the implementation as per your comments. Please take another look when convenient, happy to iterate if needed.

SuriyaaMM, Jul 14 '25

Looks like this comment got lost in the mix:

Also, we'll need to add the new functionality to the docs in this section:

https://github.com/jax-ml/jax/blob/d39f29ca3a8a3faa71f35aa427927a22869e57c2/docs/jax.scipy.rst?plain=1#L200

jakevdp, Jul 14 '25

I added it, following the entries for previously implemented distributions, but I'm not sure whether it's right. If there is a way to check this, please let me know.

SuriyaaMM, Jul 14 '25

Hi! Just checking in, is there anything else I should do on this PR, or is it good to go?

SuriyaaMM, Jul 16 '25

I hope everything is fine. Let me know if there is anything else that I need to change.

SuriyaaMM, Jul 16 '25

The "Oldest Supported NumPy" tests and the CPU tests with x64 = 1 failed. Should I promote the types before returning from args_maker(), like this?

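    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # Common floating dtype for all three inputs
      promote_dtype = jnp.result_type(x, loc, scale)
      scale = np.abs(scale) + 0.1  # Ensure scale > 0
      # Cast every argument to that dtype so SciPy and JAX receive identical inputs
      return [x.astype(promote_dtype), loc.astype(promote_dtype), scale.astype(promote_dtype)]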
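To see what the proposed promotion does, here is a standalone sketch of args_maker with mixed float16/float32 inputs. The parameterized signature and the random-number stand-in for the test suite's `rng` are hypothetical, for illustration only.

```python
import numpy as np
import jax.numpy as jnp

def args_maker(shapes, dtypes, seed=0):
  # Hypothetical stand-in for the test suite's `rng`: standard-normal samples
  r = np.random.default_rng(seed)
  x, loc, scale = (r.standard_normal(s).astype(d) for s, d in zip(shapes, dtypes))
  # Common dtype for all three inputs under JAX's promotion rules
  promote_dtype = jnp.result_type(x, loc, scale)
  scale = np.abs(scale) + 0.1  # ensure scale > 0
  return [a.astype(promote_dtype) for a in (x, loc, scale)]

args = args_maker([(3,), (3,), (3,)], [np.float16, np.float32, np.float16])
print([a.dtype for a in args])  # every argument ends up float32
```

With the cast in place, both libraries see inputs of a single dtype instead of a float16/float32 mix.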

I'm not sure why it fails, specifically for Python 3.13 in GitHub Actions. Locally, the suite passes on both Python 3.11 and Python 3.13:

echo $JAX_ENABLE_X64
1
pytest -n 2 tests/scipy_stats_test.py
================================================== test session starts ==================================================
platform linux -- Python 3.11.0, pytest-8.4.0, pluggy-1.6.0
rootdir: /home/mms/dev/jax
configfile: pyproject.toml
plugins: xdist-3.7.0, hypothesis-6.135.0
2 workers [940 items]   
................................................................................................................. [ 12%]
................................................................................................................. [ 24%]
................................................................................................................. [ 36%]
................................................................................................................. [ 48%]
................................................................................................................. [ 60%]
................................................................................................................. [ 72%]
................................................................................................................. [ 84%]
.................................................ssssssssss...................................................... [ 96%]
....................................                                                                              [100%]
====================================== 930 passed, 10 skipped in 112.76s (0:01:52) ======================================
pytest -n 2 tests/scipy_stats_test.py
================================================== test session starts ==================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/mms/dev/jax
configfile: pyproject.toml
plugins: xdist-3.8.0, hypothesis-6.135.32
2 workers [940 items]   
................................................................................................................. [ 12%]
................................................................................................................. [ 24%]
................................................................................................................. [ 36%]
................................................................................................................. [ 48%]
................................................................................................................. [ 60%]
................................................................................................................. [ 72%]
................................................................................................................. [ 84%]
..............................................................ssssssssss......................................... [ 96%]
....................................                                                                              [100%]
====================================== 930 passed, 10 skipped in 110.73s (0:01:50) ======================================

As for the -inf mismatch in the CPU test, I changed the code to:

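from jax import lax
import jax.numpy as jnp
import numpy as np
# Internal helpers used by jax.scipy.stats (private paths, may move)
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact

def logcdf(x, loc=0, scale=1):
  x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
  ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  neg_exp_z = lax.neg(lax.exp(z))
  # xlog1py fails here, and even log1p fails for some inputs in float64 mode,
  # so compute log(1 - exp(-exp(z))) as log(-expm1(-exp(z))), which stays
  # accurate when exp(z) is tiny
  log_cdf = lax.log(-lax.expm1(neg_exp_z))
  return jnp.where(ok, log_cdf, np.nan)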
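The remark that log1p fails even in float64 mode can be reproduced in plain NumPy: at z = -40, exp(z) ≈ 4.2e-18 is far below float64 machine epsilon, so exp(-exp(z)) rounds to exactly 1 and the log1p form returns -inf, while the expm1 form keeps the small difference. This sketch compares only the two formulas, not the PR's function.

```python
import numpy as np

z = np.float64(-40.0)
t = np.exp(z)  # about 4.2e-18, much smaller than float64 epsilon (~2.2e-16)

with np.errstate(divide="ignore"):
  naive = np.log1p(-np.exp(-t))  # exp(-t) == 1.0 exactly, so this is log1p(-1) = -inf
stable = np.log(-np.expm1(-t))   # expm1(-t) ~= -t, so this is ~= log(t) = z

print(naive, stable)
```

For very negative z the exact value log(1 - exp(-e^z)) is approximately z, which is what the expm1 form returns.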

I also tested it specifically against edge cases that the JAX test suite doesn't cover:

import numpy as np
import jax.numpy as jnp
import scipy.stats as osp_stats
import jax.scipy.stats as lsp_stats

edge_cases = [
    (jnp.array([-100.0]), jnp.array([0.0]), jnp.array([1.0])),
    (jnp.array([-50.0]), jnp.array([0.0]), jnp.array([1.0])),
    (jnp.array([-1000.0]), jnp.array([0.0]), jnp.array([1.0])),
    (jnp.array([0.0]), jnp.array([0.0]), jnp.array([1000.0])),
    (jnp.array([1.0]), jnp.array([0.0]), jnp.array([0.001])),
    (jnp.array([-1000.0]), jnp.array([1000.0]), jnp.array([0.1])),
    (jnp.array([-100.0, -10.0, 0.0, 10.0]), jnp.zeros(4), jnp.ones(4)),
]

failures = []  # (test_name, message) for every case that fails
for i, (x, loc, scale) in enumerate(edge_cases):
  test_name = f"Edge case {i}"
  try:
    jax_result = lsp_stats.gumbel_l.logcdf(x, loc, scale)
    scipy_result = osp_stats.gumbel_l.logcdf(np.asarray(x), np.asarray(loc), np.asarray(scale))
    # rtol matches the 5e-3 tolerance used in the PR's tests
    np.testing.assert_allclose(jax_result, scipy_result, rtol=5e-3, err_msg=test_name)
  except Exception as e:
    failures.append((test_name, str(e)))

print(failures or "all edge cases passed")

It passed. What do you suggest, @jakevdp?

SuriyaaMM, Jul 17 '25

Perhaps we should consider exposing log1mexp to the public API, probably under nn.

Related:

  • https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
  • https://www.tensorflow.org/probability/api_docs/python/tfp/math/log1mexp
  • https://www.pymc.io/projects/docs/en/stable/api/generated/pymc.math.log1mexp.html
  • https://www.rdocumentation.org/packages/VGAM/versions/1.1-13/topics/log1mexp
  • https://juliastats.org/LogExpFunctions.jl/stable/#LogExpFunctions.log1mexp
  • https://stdlib.io/docs/api/latest/@stdlib/math/iter/special/log1mexp
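For context, the Rmpfr note linked above recommends switching between two formulas at a threshold of log 2. Below is a minimal jax.numpy sketch of that idea. The name `log1mexp` and its convention, log(1 - exp(x)) for x ≤ 0, are assumptions for illustration; the TensorFlow Probability and PyMC functions linked above may use a different sign convention, and this is not an existing JAX API.

```python
import jax.numpy as jnp

def log1mexp(x):
  """log(1 - exp(x)) for x <= 0 (hypothetical helper, not a JAX API)."""
  x = jnp.asarray(x)
  return jnp.where(
      x > -jnp.log(2.0),
      jnp.log(-jnp.expm1(x)),  # x near 0: expm1 keeps the tiny value of 1 - exp(x)
      jnp.log1p(-jnp.exp(x)),  # x very negative: exp(x) is small, log1p is accurate
  )

vals = log1mexp(jnp.array([-1e-10, -0.5, -50.0]))
print(vals)
```

jnp.where evaluates both branches, so the forward values are fine, but reverse-mode gradients can pick up NaNs from the unused branch (for example log1p(-1) near x = 0); a public version would need to guard the inputs of each branch (the "double where" pattern).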

carlosgmartin, Jul 21 '25