
jax.scipy.stats has different behavior than scipy.stats when arguments are out of the support and when parameters lead to undefined densities

Open · Dpananos opened this issue 4 years ago · 2 comments

Hi all,

I've been working on adding some additional distributions to jax.scipy.stats. While going through SciPy's source code, I've noticed a few differences between scipy.stats and jax.scipy.stats.

JAX returns a finite logpmf when the argument is outside the support

When you pass a non-integer argument to the logpmf of a SciPy discrete distribution, you get -inf:

```python
>>> import jax.scipy.stats as lsp
>>> import scipy.stats as osp
>>> osp.poisson.logpmf(1.2, mu=1)
-inf
```

This makes sense, since the Poisson distribution is supported only on the non-negative integers. In JAX, however, we get a finite logpmf (and hence a non-zero probability):

```python
>>> lsp.poisson.logpmf(1.2, mu=1)
Buffer(-1.0969481, dtype=float32)
```
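
The finite value is consistent with evaluating the usual closed-form log-pmf without a support check. Assuming the implementation writes log(k!) as log Γ(k + 1), nothing in the formula fails at a non-integer k, and plugging in k = 1.2, μ = 1 reproduces the number above:

```latex
\log p(k;\mu) = k\log\mu - \log\Gamma(k+1) - \mu,
\qquad
\log p(1.2;\,1) = 1.2\cdot\log 1 - \log\Gamma(2.2) - 1 \approx 0 - 0.0969 - 1 = -1.0969
```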

Does JAX want to emulate this behaviour?
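
If so, one possible approach is to compute the unmasked log-pmf and then overwrite out-of-support arguments with -inf using `jnp.where`. The sketch below is a hypothetical helper (`poisson_logpmf_masked` is not a JAX function); it assumes `loc = 0` and float32 inputs, and treats negative or non-integer `k` as outside the support.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def poisson_logpmf_masked(k, mu):
    """Hypothetical Poisson log-pmf returning -inf outside the support.

    Assumes loc = 0. The support is the non-negative integers, so any
    negative or non-integer k is mapped to -inf, mirroring scipy.stats.
    """
    k = jnp.asarray(k, dtype=jnp.float32)
    mu = jnp.asarray(mu, dtype=jnp.float32)
    # Unmasked formula: k*log(mu) - log(k!) - mu, with log(k!) = gammaln(k + 1).
    log_probs = xlogy(k, mu) - gammaln(k + 1) - mu
    # Support check: k must be a non-negative integer.
    in_support = (k >= 0) & (k == jnp.floor(k))
    return jnp.where(in_support, log_probs, -jnp.inf)
```

With this helper, `poisson_logpmf_masked(1.2, 1.0)` gives -inf while `poisson_logpmf_masked(2.0, 1.0)` gives roughly -1.6931 (that is, -log 2 - 1).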

JAX returns a finite logpdf for parameters that result in an undefined distribution

The beta distribution is defined only for a > 0 and b > 0. SciPy returns nan when the argument is in the support of the distribution but the parameters result in an undefined density:
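
A plausible reason a direct implementation does not fail here: the log-density below only needs log|B(a, b)|, and the gamma function is finite at negative non-integers, so for a = b = -1.1 every term is a finite number even though the density cannot be normalised. Assuming log|B| is computed from log|Γ| terms, evaluating at x = 0.5 gives:

```latex
\log f(x;a,b) = (a-1)\log x + (b-1)\log(1-x) - \log\lvert B(a,b)\rvert,
\qquad
B(a,b) = \frac{\Gamma(a)\,\Gamma(b)}{\Gamma(a+b)}

\log f(0.5;\,-1.1,\,-1.1) = (-2.1)\log 0.5 + (-2.1)\log 0.5 - \log\lvert B(-1.1,-1.1)\rvert
\approx 2.9112 - 3.7566 = -0.8454
```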


import jax.scipy.stats as lsp
import scipy.stats as osp
osp.beta.logpdf(x = 0.5, a=-1.1, b=-1.1)
>>>nan

However, JAX evaluates the logpdf even for "bad" parameters:

```python
>>> lsp.beta.logpdf(x=0.5, a=-1.1, b=-1.1)
Buffer(-0.8453665, dtype=float32)
```
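
Emulating SciPy here could use the same masking idea: compute the formula, then apply `jnp.where` twice, once for out-of-support `x` (-inf) and once for invalid shape parameters (nan, which takes precedence). This is a hypothetical sketch (`beta_logpdf_masked` is not a JAX function) assuming `loc = 0`, `scale = 1`, and float32 inputs.

```python
import jax.numpy as jnp
from jax.scipy.special import betaln, xlog1py, xlogy


def beta_logpdf_masked(x, a, b):
    """Hypothetical beta log-pdf mirroring scipy.stats.beta.logpdf edge cases.

    Assumes loc = 0 and scale = 1. Invalid shape parameters (a <= 0 or
    b <= 0) give nan; otherwise x outside [0, 1] gives -inf.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    # Unmasked formula: (a-1)*log(x) + (b-1)*log(1-x) - log B(a, b).
    log_probs = xlogy(a - 1, x) + xlog1py(b - 1, -x) - betaln(a, b)
    # Arguments outside the support [0, 1] get -inf.
    log_probs = jnp.where((x < 0) | (x > 1), -jnp.inf, log_probs)
    # Undefined densities (non-positive shape parameters) get nan.
    return jnp.where((a > 0) & (b > 0), log_probs, jnp.nan)
```

For example, `beta_logpdf_masked(0.5, -1.1, -1.1)` returns nan and `beta_logpdf_masked(0.5, 2.0, 2.0)` returns about 0.4055 (log 1.5, since the Beta(2, 2) density is 6x(1 - x)).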

I guess this leads to a few questions:

  • How closely do we want to emulate scipy? I think emulating scipy in the ways I've mentioned here is pretty reasonable.
  • How do we go about testing these distributions for failures of the kind I've mentioned? @jakevdp and I discussed this very briefly here, but maybe other authors would like to chime in.
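
For the testing question, one option is to evaluate the JAX and SciPy functions on deliberately bad inputs (non-integer or negative arguments, non-positive shape parameters) and report where they disagree, counting matching -inf or nan as agreement. The helper name `scipy_mismatches`, the case lists, and the tolerance below are illustrative assumptions, not part of JAX's test suite.

```python
import numpy as np
import jax.scipy.stats as lsp
import scipy.stats as osp


def scipy_mismatches(jax_fn, scipy_fn, cases, rtol=1e-5):
    """Hypothetical helper: return the argument tuples where jax_fn and
    scipy_fn disagree. Edge-case values are compared too: np.allclose treats
    same-sign infinities as equal, and equal_nan=True does the same for nan."""
    mismatches = []
    for args in cases:
        expected = np.asarray(scipy_fn(*args), dtype=np.float64)
        actual = np.asarray(jax_fn(*args), dtype=np.float64)
        if not np.allclose(actual, expected, rtol=rtol, equal_nan=True):
            mismatches.append(args)
    return mismatches


# Edge cases from this issue, plus in-support sanity points.
POISSON_CASES = [(1.2, 1.0), (-1.0, 1.0), (2.0, 1.0), (0.0, 3.0)]
BETA_CASES = [(0.5, -1.1, -1.1), (1.5, 2.0, 2.0), (0.5, 2.0, 2.0)]

if __name__ == "__main__":
    print("poisson:", scipy_mismatches(lsp.poisson.logpmf, osp.poisson.logpmf, POISSON_CASES))
    print("beta:", scipy_mismatches(lsp.beta.logpdf, osp.beta.logpdf, BETA_CASES))
```

An empty list means the two libraries agree on every case; because the helper returns the failing arguments rather than raising, it can be wrapped in a single assertion in a test suite.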

Dpananos, Feb 05 '21 19:02

Thanks for raising this... it's strange that JAX returns valid values for invalid inputs. I think emulating scipy here would be reasonable.

jakevdp, Feb 05 '21 19:02

Hi @Dpananos

This issue will be resolved once the PRs #20885 and #20891 are merged.

Thank you.

rajasekharporeddy, Apr 23 '24 20:04

Hi @Dpananos

I tested the code above with the JAX nightly build 0.4.27.dev20240503. In that build, jax.scipy.stats.poisson.logpmf and jax.scipy.stats.beta.logpdf now match SciPy's behavior:

```python
>>> import jax
>>> jax.__version__
'0.4.27.dev20240503'
>>> import jax.scipy.stats as lsp
>>> import scipy.stats as osp
>>> osp.poisson.logpmf(1.2, mu=1)
-inf
>>> lsp.poisson.logpmf(1.2, mu=1)
Array(-inf, dtype=float32, weak_type=True)
>>> osp.beta.logpdf(x=0.5, a=-1.1, b=-1.1)
nan
>>> lsp.beta.logpdf(x=0.5, a=-1.1, b=-1.1)
Array(nan, dtype=float32, weak_type=True)
```

Thank you.

rajasekharporeddy, May 04 '24 13:05

Thanks for following up!

jakevdp, May 06 '24 17:05