numpyro
Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
I would like to be able to use JAX's `pjit` to parallelize density/gradient evaluation across multiple GPUs. This would allow you to perform standard NUTS/HMC (see [here](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Distributed_Inference_with_JAX.ipynb)) when...
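Here is a minimal sketch of the idea. It assumes a recent JAX release in which `pjit`'s functionality is exposed through `jax.jit` combined with `jax.sharding`. The observations are sharded along a one-dimensional device mesh, and the compiler partitions a jitted `value_and_grad` of a toy log density across whatever devices are visible. It also runs on a single CPU. The model (unit-variance Normal likelihood with unknown mean `mu`) and the names `log_density` and `data` are illustrative assumptions, not part of NumPyro.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all visible devices (GPUs, TPUs, or a single CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def log_density(mu, x):
    # Hypothetical toy model: x_i ~ Normal(mu, 1), log-likelihood summed over observations.
    return jnp.sum(-0.5 * (x - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi))

# Make the number of observations divisible by the number of devices.
n = 8 * len(jax.devices())
data = jnp.linspace(-1.0, 1.0, n)

# Shard observations along the "data" mesh axis; each device holds n / num_devices of them.
data_sharded = jax.device_put(data, NamedSharding(mesh, P("data")))

# jit propagates the input sharding, so the per-shard sums (and their gradient)
# are computed locally and combined with a compiler-inserted cross-device reduction.
value_and_grad = jax.jit(jax.value_and_grad(log_density))
lp, grad_mu = value_and_grad(0.0, data_sharded)
```

The same jitted function could serve as the potential/gradient evaluation inside an HMC step. Only the placement of `data` decides how the work is split.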
Adds a LogUniform distribution as in #1398. Notes: - no docs yet; there are no docstrings in other distributions, so how should I document this? - no tests yet, I'm looking into it -...
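A docstring could state the density, which follows from a change of variables. This sketch assumes the distribution is parameterized by bounds $0 < a < b$ with $\log X \sim \mathrm{Uniform}(\log a, \log b)$. The symbols $a$ and $b$ are chosen here for illustration.

```latex
\begin{aligned}
p_X(x) &= p_{\log X}(\log x)\,\left|\frac{d\log x}{dx}\right|
        = \frac{1}{\log b - \log a}\cdot\frac{1}{x}, \qquad a \le x \le b,\\
\log p_X(x) &= -\log x - \log(\log b - \log a).
\end{aligned}
```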
I fixed the typo in the example `bayesian_hierarchical_stacking.ipynb`, function `stacking`, lines 42 and 70.
Hi, I noticed that there is a typo in the Bayesian Hierarchical Stacking notebook. In function `stacking` (cell 16), line 34, the code says: ```K = lpd_point.shape[1] # number of candidate...
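For context, `K = lpd_point.shape[1]` implies that `lpd_point` has shape `(N, K)`: N data points by K candidate models. Here is a minimal sketch of the standard (non-hierarchical) stacking objective under that convention. The function name `stacking_log_score` and the softmax parameterization of the weights are illustrative assumptions, not the notebook's code. The notebook's hierarchical version lets the weights vary with the inputs, which this sketch does not attempt.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def stacking_log_score(lpd_point, logits):
    """Log score of a stacked mixture of K candidate models (hypothetical helper).

    lpd_point: (N, K) pointwise log predictive densities.
    logits:    (K,) unconstrained weights, mapped to the simplex by a softmax.
    """
    N, K = lpd_point.shape  # K = number of candidate models
    assert logits.shape == (K,)
    log_w = jax.nn.log_softmax(logits)
    # For each point n: log sum_k w_k p_k(y_n); then sum over points.
    return jnp.sum(logsumexp(lpd_point + log_w, axis=1))
```

Maximizing this score over `logits` yields the stacking weights. Reading `K` from `shape[1]` rather than `shape[0]` is what keeps the weights aligned with the model axis.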
How about a log-uniform distribution, i.e., a variable whose log is uniformly distributed? Implementation:
```
class LogUniform(dist.Uniform):
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        sample =...
```
Hi, and first of all, thank you for this awesome project! Properties of `Transforms` objects, such as `codomain`, rely on identity checks like this one: https://github.com/pyro-ppl/numpyro/blob/34e0cdf4fa0ab9a0300a0d894d6758419fb46f40/numpyro/distributions/transforms.py#L173 which are not...
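The fragility of identity checks can be shown in plain Python, independent of NumPyro. In this sketch, `_Real` is a hypothetical stand-in for a module-level singleton constraint. A copy made with `copy.deepcopy` behaves identically but is a different object. An `is` check therefore fails, while an `==` check defined on the class succeeds.

```python
import copy

class _Real:
    """Hypothetical stand-in for a singleton constraint such as `real`."""

    def __eq__(self, other):
        # Treat every instance of this class as the same constraint.
        return type(self) is type(other)

    def __hash__(self):
        return hash(type(self))

real = _Real()  # module-level "singleton"
real_copy = copy.deepcopy(real)  # deepcopy constructs a new instance

identity_result = real_copy is real   # identity check: fails for the copy
equality_result = real_copy == real   # equality check: succeeds for the copy
```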
Some modules require more than one input when initializing; these can be passed through `kwargs`. However, this doesn't work in some cases. For example:
```python
class RNN(nn.Module):
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast='params',...
```
Hi, I have been using numpyro for my research and have been finding it very useful! I have a quick question that the documentation doesn't make clear to me. Is...
AFAIK [DiscreteHMCGibbs](https://num.pyro.ai/en/latest/_modules/numpyro/infer/hmc_gibbs.html#DiscreteHMCGibbs) does not make use of plate information when computing Gibbs updates for discrete latent variables. It would be nice to support this, as leveraging this information can make...
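Here is a sketch of why plate information helps. When the discrete latents `z_n` in a plate are conditionally independent given the remaining variables, their Gibbs conditionals factorize. All of them can then be resampled in one vectorized draw instead of one element at a time. The two-component Gaussian mixture below, with mixing weight `pi` and component means `mu`, is a hypothetical example model. `gibbs_update_z` is a hand-written Gibbs step, not how `DiscreteHMCGibbs` is implemented.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def gibbs_update_z(key, x, pi, mu):
    """Resample all mixture assignments z_n in {0, 1} at once (hypothetical model).

    x:  (N,) observations, x_n ~ Normal(mu[z_n], 1)
    pi: P(z_n = 1)
    mu: (2,) component means
    """
    log_prior = jnp.log(jnp.array([1.0 - pi, pi]))                 # (2,)
    log_lik = norm.logpdf(x[:, None], loc=mu[None, :], scale=1.0)  # (N, 2)
    logits = log_prior + log_lik                                   # (N, 2)
    # z_n are conditionally independent given (pi, mu), so one batched
    # categorical draw is an exact Gibbs update for the whole plate.
    return jax.random.categorical(key, logits, axis=-1)

z = gibbs_update_z(jax.random.PRNGKey(0), jnp.array([-10.0, 10.0, -9.5]), 0.5, jnp.array([-10.0, 10.0]))
```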