blackjax
Porting tests to chex by using chex runtime assertions
This PR ports some tests that weren't using Chex without weakening their assertions: I was able to replace pure Python runtime assertions with Chex assertions. The only thing I couldn't port easily was functions defined piecewise through Python control flow, for example a test logdensity that takes only two possible values depending on the input. I had to replace something like:
0.0 if all(position > 59.0) else 0.5
by
jnp.where(position > 59.0, 0.0, 0.5).sum()
This is the approach suggested in the JAX docs, but it makes the code harder to read, so I have added comments to explain it.
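A minimal sketch of why the rewrite is needed. The functions `logdensity_branch`, `logdensity_where` and `logdensity_where_all` are hypothetical stand-ins, not the actual test code from this PR. The sketch shows that the Python branch cannot be traced by `jax.jit`, while the `jnp.where` form can. The third variant, which reduces with `jnp.all` before selecting, is our own suggestion and is not part of the PR.

```python
import jax
import jax.numpy as jnp


def logdensity_branch(position):
    # Original style: a Python `if` on array values. It needs concrete
    # values, so it works eagerly but fails when traced by jax.jit.
    return 0.0 if all(position > 59.0) else 0.5


def logdensity_where(position):
    # Style adopted in the PR: element-wise selection, then a sum.
    # For a one-element `position` this agrees with the branch above.
    return jnp.where(position > 59.0, 0.0, 0.5).sum()


def logdensity_where_all(position):
    # Hypothetical alternative: reduce with jnp.all first, then select,
    # which mirrors `all(...)` for positions with any number of elements.
    return jnp.where(jnp.all(position > 59.0), 0.0, 0.5)


# Both jnp.where versions can be compiled.
jitted_where = jax.jit(logdensity_where)
jitted_where_all = jax.jit(logdensity_where_all)

# The Python branch cannot: converting a traced value to bool raises
# a ConcretizationTypeError (or a subclass of it) at trace time.
try:
    jax.jit(logdensity_branch)(jnp.array([60.0]))
    branch_jits = True
except jax.errors.ConcretizationTypeError:
    branch_jits = False
```

For a one-element position such as `jnp.array([60.0])` all three functions return the same value. For `jnp.array([10.0, 20.0])`, the branch and the `jnp.all` variant return 0.5, while the summed `jnp.where` returns 1.0. The two forms therefore agree for one-element positions but can differ for longer ones.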
A few important guidelines and requirements before we can merge your PR:
- [yes] If I add a new sampler, there is an issue discussing it already;
- [yes] We should be able to understand what the PR does from its title only;
- [yes] There is a high-level description of the changes;
- [yes] There are links to all the relevant issues, discussions and PRs;
- [yes] The branch is rebased on the latest main commit;
- [yes] Commit messages follow these guidelines;
- [yes] The code respects the current naming conventions;
- [yes] Docstrings follow the numpy style guide;
- [yes] pre-commit is installed and configured on your machine, and you ran it before opening the PR;
- [yes] There are tests covering the changes;
- [no] The doc is up-to-date;
- [no] If I add a new sampler, I added/updated related examples.
Friendly ping