Junpeng Lao

Results 17 issues of Junpeng Lao

I didn't manage to run the notebook before merging; we should try to get rid of the divergences.

After #210, it should be straightforward to add multi-pathfinder (ref: https://arxiv.org/pdf/2108.03782.pdf). The code snippet below mostly works (it still needs an implementation of Pareto smoothed importance sampling). ```python multi_pathfinder = jax.vmap(lambda rng_key,...

enhancement
good first issue
sampler
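The multi-path idea above can be sketched in plain JAX as follows. This is only an illustration under stated assumptions: `single_path` is a hypothetical stand-in for one Pathfinder run (it is not the BlackJAX API), the target is a toy 2-D standard normal, and the final step uses plain self-normalised importance resampling rather than the Pareto smoothed importance sampling the issue says is still missing.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    # Toy target (assumption): unnormalised standard normal in 2 dimensions.
    return -0.5 * jnp.sum(x ** 2)


def single_path(rng_key, num_draws=50):
    """Hypothetical stand-in for a single Pathfinder run.

    Draws from a randomly located unit-variance Gaussian approximation and
    returns the draws together with their unnormalised log importance weights.
    """
    key_init, key_draw = jax.random.split(rng_key)
    loc = jax.random.normal(key_init, (2,))
    draws = loc + jax.random.normal(key_draw, (num_draws, 2))
    log_q = -0.5 * jnp.sum((draws - loc) ** 2, axis=-1)
    log_p = jax.vmap(logdensity_fn)(draws)
    return draws, log_p - log_q


# Run several independent paths in parallel by mapping over the RNG keys.
multi_path = jax.vmap(single_path)
keys = jax.random.split(jax.random.PRNGKey(0), 4)
draws, log_weights = multi_path(keys)  # shapes (4, 50, 2) and (4, 50)

# Pool all paths and resample with plain importance weights (NOT PSIS).
flat_draws = draws.reshape(-1, 2)
flat_log_weights = log_weights.reshape(-1)
idx = jax.random.categorical(jax.random.PRNGKey(1), flat_log_weights, shape=(100,))
resampled = flat_draws[idx]
```

Because every path only returns arrays, `jax.vmap` can batch the whole routine; a real implementation would replace the resampling step with smoothed weights.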

Currently we cannot vmap `window_adaptation` because one of its return values is a `kernel`: https://github.com/blackjax-devs/blackjax/blob/f4221d041902b472095983cd83edffb76a395f51/blackjax/kernels.py#L584-L587 Proposal: add a kwarg `return_kernel=True` to `window_adaptation`: ```python step_size, inverse_mass_matrix = final(last_warmup_state) if return_kernel:...

enhancement
adaptation
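To illustrate why such a flag helps, here is a minimal sketch of the pattern. `toy_window_adaptation` is a hypothetical stand-in, not the BlackJAX function, and its internals are invented for illustration; the point is only that `jax.vmap` requires array outputs, so a routine that returns a Python callable cannot be mapped, while the array-only branch can.

```python
import jax
import jax.numpy as jnp


def toy_window_adaptation(rng_key, return_kernel=True):
    """Hypothetical adaptation routine (illustration only).

    With return_kernel=True it returns a Python callable, which is not a
    valid output for jax.vmap. With return_kernel=False it returns arrays only.
    """
    step_size = jnp.exp(jax.random.normal(rng_key))
    inverse_mass_matrix = jnp.ones(3)
    if return_kernel:
        def kernel(position):
            # Toy "kernel": a deterministic move scaled by the adapted values.
            return position + step_size * inverse_mass_matrix
        return kernel, (step_size, inverse_mass_matrix)
    return step_size, inverse_mass_matrix


# Adapt several chains at once by mapping over the RNG keys.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
step_sizes, inverse_mass_matrices = jax.vmap(
    lambda key: toy_window_adaptation(key, return_kernel=False)
)(keys)
```

The caller can then build one kernel per chain (or a single batched kernel) from the returned `step_sizes` and `inverse_mass_matrices`.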

Close #317

dont-merge

Currently the docs only show HEAD; we should at least add the latest release tag to the docs.

documentation
enhancement
good first issue

We internally flip the sign of `log_prob_fn` and use `potential_fun` instead. This is starting to cause headaches in the code base; we should refactor to use `log_prob_fn` directly.

enhancement
refactoring
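As a small illustration of the sign flip described above, the sketch below defines a toy `log_prob_fn` and the corresponding `potential_fun` wrapper, then shows that a half-step momentum update gives the same result either way. The target density, step size, and variable names are assumptions chosen for the example, not code from the library.

```python
import jax
import jax.numpy as jnp


def log_prob_fn(x):
    # Toy log density (assumption): unnormalised standard normal.
    return -0.5 * jnp.sum(x ** 2)


def potential_fun(x):
    # The internal sign flip: potential energy is the negative log density.
    return -log_prob_fn(x)


position = jnp.array([1.0, -2.0])
momentum = jnp.array([0.5, 0.5])
step_size = 0.1

# Half-step momentum update written with the potential...
momentum_from_potential = momentum - 0.5 * step_size * jax.grad(potential_fun)(position)
# ...and the same update written directly with the log density.
momentum_from_log_prob = momentum + 0.5 * step_size * jax.grad(log_prob_fn)(position)
```

Both updates agree, so using `log_prob_fn` directly only changes signs in the integrator and removes the need for the wrapper.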

- [ ] Improve some shape handling related to the batch dimension (e.g., https://github.com/blackjax-devs/blackjax/pull/229/files#r973546362)
- [ ] Change the per-variable step size to use a scalar step size and diag mass...

Hi Matt, We (@rlouf and I) are developers of BlackJAX, an MCMC library built on top of JAX with an initial heavy focus on HMC. We are wondering if you are...

`jax.experimental.host_callback` is deprecated (for background, see https://github.com/google/jax/discussions/13555), and we should switch to using `jax.debug.print`/`jax.debug.callback` in [blackjax/progress_bar.py](https://github.com/blackjax-devs/blackjax/blob/main/blackjax/progress_bar.py)

good first issue
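A minimal sketch of what progress reporting with `jax.debug.callback` could look like inside a `jax.lax.scan` loop. The `run` and `report_progress` names and the toy loop body are assumptions for illustration; this is not the contents of `blackjax/progress_bar.py`.

```python
import jax
import jax.numpy as jnp


def run(num_steps):
    def report_progress(step):
        # Runs on the host; `step` arrives as a concrete array value.
        print(f"step {int(step) + 1}/{num_steps}")

    def body(carry, step):
        # Side-effecting host call that works under jit and scan.
        jax.debug.callback(report_progress, step)
        return carry + 1.0, None

    final, _ = jax.lax.scan(body, jnp.zeros(()), jnp.arange(num_steps))
    return final


# num_steps is a static argument so it can size the scan and be formatted on the host.
result = jax.jit(run, static_argnums=0)(5)
```

By default the callback is unordered, so the printed lines are not guaranteed to appear in step order; passing `ordered=True` to `jax.debug.callback` requests in-order execution.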