blackjax

Porting tests to chex by using chex runtime assertions

Open · ciguaran opened this issue 1 year ago • 1 comment

This PR ports some tests that weren't using Chex without losing any of their checks: I was able to replace the pure-Python runtime assertions with Chex assertions. The only thing I couldn't port easily is functions whose value depends on a condition on the input, for example a test logdensity that takes only one of two values depending on its input. I had to replace something like:

0.0 if all(position > 59.0) else 0.5

by

jnp.where(position > 59.0, 0.0, 0.5).sum()

This is the approach the JAX docs suggest in place of Python control flow on array values (which fails under `jax.jit`), but it makes the code harder to read, so I added comments to explain it.

A few important guidelines and requirements before we can merge your PR:

  • [yes] If I add a new sampler, there is an issue discussing it already;
  • [yes] We should be able to understand what the PR does from its title only;
  • [yes] There is a high-level description of the changes;
  • [yes] There are links to all the relevant issues, discussions and PRs;
  • [yes] The branch is rebased on the latest main commit;
  • [yes] Commit messages follow these guidelines;
  • [yes] The code respects the current naming conventions;
  • [yes] Docstrings follow the numpy style guide;
  • [yes] pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • [yes] There are tests covering the changes;
  • [no] The doc is up-to-date;
  • [no] If I add a new sampler, I added/updated related examples.

ciguaran · Mar 31 '23 17:03

Friendly ping

junpenglao · Oct 23 '23 09:10