numpyro
                                
Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
I am trying to use HMCGibbs sampling with more than one chain using `chain_method="vectorized"`, but there appears to be some problem with splitting the random keys. Consider this toy example...
The following explains the issue. ```python import distrax import numpyro as pyro import jax.numpy as jnp import jax.random as jrandom import tensorflow_probability.substrates.jax as tfp import matplotlib.pyplot as plt key =...
Currently, `constraints.greater_than_eq` is supported in [torch](https://pytorch.org/docs/stable/distributions.html#torch.distributions.constraints.greater_than_eq), but we only have `constraints.greater_than` in numpyro. Can we add an alias for `greater_than_eq` with something like `constraints.interval(x, jnp.inf)` to the constraints namespace (and...
It would be nice to have `equinox_module` and `random_equinox_module` model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py as [Equinox](https://github.com/patrick-kidger/equinox) seems to be in quite active development. Would this be a good addition? I could...
Hi there, I have an issue: I'm trying to serve a NumPyro model using MLflow and MLServer. My model has varying input sizes and needs to re-estimate all parameters regularly....
Similar to #1710. The reproduced code is as follows. ``` import numpyro import numpyro.distributions as dist from numpyro.handlers import do def model(y=None): alpha = numpyro.sample("alpha", dist.Normal(0., 1.)) beta = numpyro.sample("beta",...
I am currently working on a project where we embed a VAE-decoder inside a model. Accordingly, we need to sample `z`s from a multivariate normal distribution, but we are not...
As per title. This can cause sampling bias when we have a jump near `max_delta_energy`. Here's a concrete example translated from @nhuurre's Stan code at https://discourse.mc-stan.org/t/divergence-check-does-not-satisfy-time-reversibility/33738. ```python import jax import...
Here's a reproducible example that's taken nearly directly from the Gaussian Mixture Model tutorial. The AutoContinuous guide seems to be the failure mode. ```python import jax.numpy as jnp import jax.random...
Hello guys, I come from the TensorFlow Distributions world and was looking for a lightweight alternative and was pleasantly surprised to see that Pyro is available for JAX via your...