New window_adaptation syntax

Open forgi86 opened this issue 1 year ago • 4 comments

Hello, I am new to blackjax and recently came across your nice examples. However, I had to change a few lines of code to make them work with recent versions of the jax/blackjax ecosystem, in particular the adaptation algorithm. For instance, in bayesian-neural-network.ipynb, I had to change the lines following the definition of "potential" to:

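import blackjax

# Window adaptation tunes the NUTS step size and inverse mass matrix
warmup = blackjax.window_adaptation(blackjax.nuts, potential)
(state, parameters), _ = warmup.run(key_warmup, params)

# Build the NUTS kernel from the adapted parameters, then sample
kernel = blackjax.nuts(potential, **parameters).step
states = inference_loop(key_samples, kernel, state, num_steps)
sampled_params = states.position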

forgi86, Sep 17 '24 08:09

Hi @forgi86,

Could you please open a PR with these changes?

Thanks!

gerdm, Sep 17 '24 09:09

Hello @gerdm,

I opened a PR, but only fixed the bayesian-neural-network example. I was trying to fix the bnn-hierarchical-flax example (the one most relevant to my current work), but there seems to be another bug there.

I can't evaluate the potential. Running

potential(params_all)

throws ValueError: Arity mismatch between trees.

Unfortunately I don't have time to look into it at the moment...

forgi86, Sep 17 '24 15:09

OK, I also fixed the bug in the hbnn potential. The variable params_sigma_tree in build_sigma_tree had an extra leading singleton dimension (perhaps due to a change in Flax Linen's pytree structure?).
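params_sigma_tree and build_sigma_tree are not shown in this thread, so the sketch below rests on an assumption: it reads the extra singleton dimension as an extra length-1 wrapper level in the pytree. The trees params and sigma_bad and the helpers scaled and same_structure are hypothetical stand-ins. jax.tree_util.tree_map raises ValueError when its input trees differ in structure, and comparing jax.tree_util.tree_structure of each tree is a quick check before evaluating the potential. The exact error message varies across JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins; the notebook's real trees are not shown
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros(3)}}
# Same leaves, but wrapped in an extra length-1 list level
sigma_bad = [{"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.ones(3)}}]

def scaled(p, s):
    # Leafwise combination of two trees, as a potential might do
    return jax.tree_util.tree_map(lambda x, y: x / y, p, s)

def same_structure(a, b):
    # Two pytrees can be mapped together only if their treedefs match
    return jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure(b)

raised = False
try:
    scaled(params, sigma_bad)
except ValueError as err:
    raised = True
    print("ValueError:", err)

sigma_fixed = sigma_bad[0]  # drop the extra leading level
out = scaled(params, sigma_fixed)
```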

All fixed in my fork (https://github.com/forgi86/bayes). I also made a few changes to remove deprecation warnings.

forgi86, Sep 19 '24 09:09

Thanks for your contribution @forgi86!

I'll take a look at this tonight.

gerdm, Sep 19 '24 09:09

Closed by #3

gerdm, Oct 02 '24 17:10