blackjax
Add an example that uses Aesara / Aeppl
In this PR I add an example that uses Aesara and Aeppl to build the model's logprob and compile it to JAX. Closes #173
```python
import aesara
import aesara.tensor as at
from aeppl import joint_logprob
from aeppl.transforms import TransformValuesOpt, LogOddsTransform

# Add an improper prior on `a` and `b`
a_vv = at.scalar('a')
b_vv = at.scalar('b')
logprior = -2.5 * at.log(a_vv + b_vv)

# `n_rat_tumors` and `group_size` are not defined in this snippet;
# they are assumed to come from the data loaded earlier in the example
srng = at.random.RandomStream(0)
theta_rv = srng.beta(a_vv, b_vv, size=(n_rat_tumors,))
Y_rv = srng.binomial(group_size, theta_rv)

# We apply the same transformation on `theta` as we did above
theta_vv = theta_rv.clone()
Y_vv = Y_rv.clone()
transforms_op = TransformValuesOpt(
    {theta_vv: LogOddsTransform()}
)
loglikelihood = joint_logprob(
    {Y_rv: Y_vv, theta_rv: theta_vv},
    extra_rewrites=transforms_op,
)

# Compile a function that computes the model's logprob in the transformed space
logprob = logprior + loglikelihood
logprob_fn = aesara.function((a_vv, b_vv, theta_vv, Y_vv), logprob)
```
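To make explicit what "logprob in the transformed space" means here, the following is a hand-written sketch of the same log-density using only the Python standard library. It is not the graph that Aeppl produces. The names `log_sigmoid` and `rat_tumor_logprob` and the toy inputs are hypothetical, and the sketch assumes that `LogOddsTransform` maps `theta` to `log(theta / (1 - theta))`, so the log-Jacobian of the inverse map is `log(theta) + log(1 - theta)`.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def rat_tumor_logprob(a, b, theta_logit, y, group_size):
    """Log-density of the model with each theta on the log-odds scale.

    Hypothetical reference implementation: `theta_logit`, `y` and
    `group_size` are equal-length sequences, one entry per group.
    """
    # Improper prior on the Beta hyperparameters, as in the snippet above
    logp = -2.5 * math.log(a + b)
    # Log normalising constant of Beta(a, b)
    log_beta_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)

    for x, y_i, n_i in zip(theta_logit, y, group_size):
        log_theta = log_sigmoid(x)      # log(theta), theta = sigmoid(x)
        log_1m_theta = log_sigmoid(-x)  # log(1 - theta)

        # Beta(a, b) log-density evaluated at theta
        logp += log_beta_norm + (a - 1) * log_theta + (b - 1) * log_1m_theta
        # Log-Jacobian of theta = sigmoid(x) (assumed log-odds transform)
        logp += log_theta + log_1m_theta
        # Binomial(n_i, theta) log-pmf evaluated at y_i
        log_binom_coef = (
            math.lgamma(n_i + 1) - math.lgamma(y_i + 1) - math.lgamma(n_i - y_i + 1)
        )
        logp += log_binom_coef + y_i * log_theta + (n_i - y_i) * log_1m_theta

    return logp


# Toy inputs (hypothetical): one group, 1 tumor out of 2 rats, theta = 0.5
value = rat_tumor_logprob(1.0, 1.0, [0.0], [1], [2])
print(value)
```

With these toy inputs each term can be checked by hand: the prior contributes `-2.5 * log(2)`, the Beta(1, 1) density contributes 0, the Jacobian contributes `-2 * log(2)`, and the binomial term contributes `-log(2)`, for a total of `-5.5 * log(2)`.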
Thank you for opening a PR!
A few important guidelines and requirements before we can merge your PR:
- [x] If I add a new sampler, there is an issue discussing it already;
- [x] We should be able to understand what the PR does from its title only;
- [x] There is a high-level description of the changes;
- [x] There are links to all the relevant issues, discussions and PRs;
- [x] The branch is rebased on the latest `main` commit;
- [x] Commit messages follow these guidelines;
- [x] The code respects the current naming conventions;
- [x] Docstrings follow the numpy style guide;
- [ ] `pre-commit` is installed and configured on your machine, and you ran it before opening the PR;
- [x] There are tests covering the changes;
- [ ] The doc is up-to-date;
- [x] If I add a new sampler, I added/updated related examples.
Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.
Update on this PR?
This is straightforward; the only thing blocking me is my schedule.