efax
Generalized inverse Gaussian
Hi Neil,
I am trying to add the generalized inverse Gaussian distribution. However, there is an issue with the log of the modified Bessel function of the second kind, which appears in the log normalizer and the to_exp method: https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution
There is no JAX implementation of this function (https://github.com/jax-ml/jax/pull/17038), so I use this package: https://github.com/tk2lab/logbesselk
However, it prevents me from using jvp in the log normalizer, and it does not support jax.grad (https://github.com/tk2lab/logbesselk/issues/33), so in the current version I simply use the central finite difference
dlogk_dp = (logk(p + eps, z) - logk(p - eps, z)) / (2.0 * eps)
which is not accurate.
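For now, here is a minimal sketch of how the finite-difference order derivative could be wrapped in jax.custom_jvp so that jax.grad works through the log normalizer. It assumes logbesselk exposes log_bessel_k(v, x) (I have not verified the exact import path), keeps the inaccurate finite difference for the order direction, and uses the standard recurrence K_p'(z) = -(K_{p-1}(z) + K_{p+1}(z)) / 2 for the argument direction:

import jax
import jax.numpy as jnp
from logbesselk.jax import log_bessel_k  # assumed import path; check the package

@jax.custom_jvp
def logk(p, z):
    # log K_p(z) computed by logbesselk
    return log_bessel_k(p, z)

@logk.defjvp
def logk_jvp(primals, tangents):
    p, z = primals
    dp, dz = tangents
    out = logk(p, z)
    eps = 1e-6
    # order derivative: central finite difference (same approximation as above)
    dlogk_dp = (logk(p + eps, z) - logk(p - eps, z)) / (2.0 * eps)
    # argument derivative from the recurrence, evaluated in log space
    dlogk_dz = -0.5 * (jnp.exp(logk(p - 1.0, z) - out) + jnp.exp(logk(p + 1.0, z) - out))
    return out, dlogk_dp * dp + dlogk_dz * dz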
Please let me know if you have any suggestions on this issue. TensorFlow Probability has an implementation of the log Bessel function (https://www.tensorflow.org/probability/api_docs/python/tfp/math/log_bessel_kve); I am not sure whether there is a way to use it in the current framework.
Besides that, the pdf and sampling should work well.
I am trying to add the generalized inverse Gaussian distribution.
Awesome!! Very cool.
Please let me know if you have any suggestions on this issue.
Have you looked in TensorFlow Probability? I don't know much about Bessel functions, but is this close to what you're looking for? If not, I suggest you request it there. Also, this issue is worth taking a look at, and maybe asking there.
What do you think?
Yes, tfp.math.log_bessel_kve is the one I am looking for. But can I use it directly in the log normalizer without breaking ExpToNat? I am new to JAX, so I am not sure how to incorporate it into the JAX framework.
But can I use it directly in the log normalizer without breaking ExpToNat?
Yes, I think so. You can see how I imported other Bessel functions into _src.tools.
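For reference, a sketch of how that wrapper could look if TFP's JAX substrate exposes log_bessel_kve (I have not checked that the substrate includes it, and this is not necessarily how _src.tools is organized):

from tensorflow_probability.substrates import jax as tfp  # assumed to provide math.log_bessel_kve

def log_kve(v, z):
    # log of the exponentially scaled modified Bessel function: kve(v, z) = K_v(z) * exp(z)
    return tfp.math.log_bessel_kve(v, z)

def log_k(v, z):
    # log K_v(z) = log kve(v, z) - z
    return log_kve(v, z) - z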
Added the log_kve to _src.tools.
However, test_distributions.py::test_conversion still fails, probably because:
- the natural parameters negative_a_over_two and negative_b_over_two must be negative;
- the Newton method used in ExpToNat does not support constraints;
- the initial parameters initial_search_parameters are zeros.
I wonder if there is any plan to:
- implement a constrained optimizer;
- make initial_search_parameters customizable for different distributions; in the case of GIG, we can use IG's to_nat to get the initial parameters.
The other tests in test_distributions.py pass.
the Newton method used in ExpToNat does not support constraints
Instead of adding constraints, the trick is to ensure that the flattened parametrization is over the entire plane. It may be a bit hard to understand: the beta distribution also has constrained parameters, but it has no problem with ExpToNat because its natural parameters have support over a constrained ring. This way, ExpToNat uses RealField.flattened(map_to_plane=True), which is unconstrained.
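To illustrate the map-to-plane idea (this is only an illustration, not efax's RealField internals): a parameter constrained to be negative can be represented by an unconstrained real value, so Newton's method in ExpToNat can search over all of R.

import jax.numpy as jnp

def to_plane(eta):
    # eta < 0 mapped to an unconstrained t in R
    return jnp.log(-eta)

def from_plane(t):
    # any real t maps back to a strictly negative eta
    return -jnp.exp(t)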
I see you used this in your distribution, so I wonder why it's not working. I can look at it later this week if you get stuck. Just let me know.
FYI, these are the errors I get running this PR:
======
FAILED tests/test_distributions.py::test_conversion[GeneralizedInverseGaussian] - Failed: Conversion failure
FAILED tests/test_entropy_gradient.py::test_nat_entropy_gradient[GeneralizedInverseGaussian] - TypeError: Gradient only defined for scalar-output functions. Output had shape: (7, 13).
FAILED tests/test_entropy_gradient.py::test_exp_entropy_gradient[GeneralizedInverseGaussian] - Failed: Non-finite gradient found for distributions: list
FAILED tests/test_hessian.py::test_sampling_cotangents[GeneralizedInverseGaussianEP] - NotImplementedError: Differentiation rule for 'random_gamma_grad' not implemented
FAILED tests/test_hessian.py::test_sampling_cotangents[GeneralizedInverseGaussianNP] - TypeError: GeneralizedInverseGaussianNP.sample() missing 1 required positional argument: 'shape'
FAILED tests/test_match_scipy.py::test_nat_entropy[GeneralizedInverseGaussian] - AssertionError:
FAILED tests/test_match_scipy.py::test_exp_entropy[GeneralizedInverseGaussian] - AssertionError:
FAILED tests/test_match_scipy.py::test_pdf[GeneralizedInverseGaussian] - AssertionError:
FAILED tests/test_match_scipy.py::test_maximum_likelihood_estimation[GeneralizedInverseGaussian] - AssertionError:
FAILED tests/test_sampling.py::test_sampling_and_estimation[GeneralizedInverseGaussianEP] - jax.errors.KeyReuseError: In pjit, argument 0 is already consumed.
FAILED tests/test_sampling.py::test_sampling_and_estimation[GeneralizedInverseGaussianNP] - jax.errors.KeyReuseError: In pjit, argument 0 is already consumed.
FAILED tests/test_shapes.py::test_shapes[GeneralizedInverseGaussian] - ValueError: Domain error in arguments. The `scale` parameter must be positive for all distributions, and many distributions have restrictions on shape parameters. Please see the `scipy.stats.geninvgauss` documentation for details.
You may want to squash and rebase onto main since it's running with some old dependencies.
Thanks! I will take a look at this over the weekend.
Cool, just checked and it looks like you forgot the constraint on p_minus_one: JaxRealArray = distribution_parameter(ScalarSupport())?
Cool, just checked and it looks like you forgot the constraint on
p_minus_one: JaxRealArray = distribution_parameter(ScalarSupport())?
This value should not have a constraint; see https://en.wikipedia.org/wiki/Generalized_inverse_Gaussian_distribution
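To spell out why, using the parametrization from the linked Wikipedia article:

$$f(x; a, b, p) = \frac{(a/b)^{p/2}}{2 K_p(\sqrt{ab})}\, x^{p-1} e^{-(a x + b/x)/2}, \qquad x > 0,\ a > 0,\ b > 0,\ p \in \mathbb{R}.$$

With sufficient statistics (log x, x, 1/x), the natural parameters are (p - 1, -a/2, -b/2), so p_minus_one is unconstrained while negative_a_over_two and negative_b_over_two must be negative.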
You may want to squash and rebase onto main since it's running with some old dependencies.
I should already have the latest updates from main? I see your last change 6a058e67152dc09e687499ddb121fd581e8abd02 in my branch.
This seems to be a numerical issue with the log Bessel function when taking its derivative by finite differences. I changed eps from 1e-6 to 1e-10 and Newton's method now converges. However, the round-tripped natural parameters are not the same:
import jax.numpy as jnp

from efax import GeneralizedInverseGaussianNP  # class added in this PR

p_minus_one = jnp.array(0.9891)
negative_a_over_two = jnp.array(-3.5979)
negative_b_over_two = jnp.array(-0.4638)

gig_np = GeneralizedInverseGaussianNP(
    p_minus_one=p_minus_one,
    negative_a_over_two=negative_a_over_two,
    negative_b_over_two=negative_b_over_two,
)

# Round trip: natural -> expectation -> natural
gig_ep = gig_np.to_exp()
gig_np_from_ep = gig_ep.to_nat()

print(gig_np_from_ep)
print(gig_np)
print(gig_np_from_ep.to_exp())
print(gig_ep)
GeneralizedInverseGaussianNP(p_minus_one=Array(-2.00653748, dtype=float64), negative_a_over_two=Array(-1.61145512, dtype=float64), negative_b_over_two=Array(-1.30899207, dtype=float64))
GeneralizedInverseGaussianNP(p_minus_one=Array(0.9891, dtype=float64, weak_type=True), negative_a_over_two=Array(-3.5979, dtype=float64, weak_type=True), negative_b_over_two=Array(-0.4638, dtype=float64, weak_type=True))
GeneralizedInverseGaussianEP(mean_log=Array(-0.40242353, dtype=float64), mean=Array(0.77495462, dtype=float64), mean_inv=Array(1.72296084, dtype=float64))
GeneralizedInverseGaussianEP(mean_log=Array(-0.39357565, dtype=float64), mean=Array(0.77495462, dtype=float64), mean_inv=Array(1.72296084, dtype=float64))
A small difference in mean_log ends up as a large difference in the natural parameters.
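A rough way to see why (a first-order argument I have not checked numerically for this distribution): with log normalizer $A(\eta)$, the expectation parameters are $\mu = \nabla A(\eta)$, so

$$\Delta\eta \approx \big(\nabla^2 A(\eta)\big)^{-1} \Delta\mu,$$

where $\nabla^2 A(\eta)$ is the Fisher information. If it is ill-conditioned near these parameters, a small error in mean_log (including the error from the finite-difference dlogk_dp) is amplified into a large error in the natural parameters.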
The pdfs are not the same. I compared the results to the scipy pdf to make sure the pdfs are correct:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import geninvgauss

# Map (p, a, b) to scipy's parametrization: geninvgauss(p, b=sqrt(a*b), scale=sqrt(b/a)).
p = gig_np.p_minus_one + 1
a = -2 * gig_np.negative_a_over_two
b = -2 * gig_np.negative_b_over_two
gig_sp = geninvgauss(p=p, b=np.sqrt(a * b), scale=np.sqrt(b / a))

p = gig_np_from_ep.p_minus_one + 1
a = -2 * gig_np_from_ep.negative_a_over_two
b = -2 * gig_np_from_ep.negative_b_over_two
gig_sp_from_ep = geninvgauss(p=p, b=np.sqrt(a * b), scale=np.sqrt(b / a))

x_values = np.linspace(0.001, 10, 1000)
plt.figure(figsize=(8, 6))
plt.plot(x_values, gig_np.pdf(x_values), label='jax pdf original')
plt.plot(x_values, gig_sp.pdf(x_values), '--', label='scipy pdf original')
plt.plot(x_values, gig_np_from_ep.pdf(x_values), label='jax pdf NatToExp')
plt.plot(x_values, gig_sp_from_ep.pdf(x_values), '--', label='scipy pdf NatToExp')
plt.legend()
plt.show()
When I sample random numbers from the two scipy distributions, the empirical sufficient statistics are very similar:
n = 100000
x_sample = gig_sp.rvs(size=n)
x_sample_from_ep = gig_sp_from_ep.rvs(size=n)
print(f"mean log: {jnp.log(x_sample).mean()}, {jnp.log(x_sample_from_ep).mean()}")
print(f"mean: {x_sample.mean()}, {x_sample_from_ep.mean()}")
print(f"mean inv: {(1/x_sample).mean()}, {(1/x_sample_from_ep).mean()}")
mean log: -0.39378250590847924, -0.40225233681070754
mean: 0.774341182300226, 0.7746143088905203
mean inv: 1.723141011106831, 1.7219595091758464
I will do more research on this numerical instability.