numpyro
Typo in the Bayesian Hierarchical Stacking example
Hi, I noticed a typo in the Bayesian Hierarchical Stacking notebook. In the function stacking (cell 16), line 34, the code says:
K = lpd_point.shape[1] # number of candidate models
However, the function argument is exp_lpd_point. The example still runs only because Python lets a function read names from the enclosing (global) scope, so the global lpd_point is silently used instead. The same typo appears in the same function on line 62:
logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)
Best
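The scoping behaviour described above can be reproduced with a minimal sketch. The global array and both function names are hypothetical stand-ins for the notebook's module-level lpd_point and its stacking function, not the notebook code itself.

```python
import numpy as np

# Hypothetical global, standing in for the notebook's module-level lpd_point.
lpd_point = np.zeros((5, 3))

def stacking_buggy(exp_lpd_point):
    # Typo: lpd_point is not a parameter, so Python looks it up in the
    # global scope and silently uses the global array.
    return lpd_point.shape[1]

def stacking_fixed(exp_lpd_point):
    # Uses the argument that was actually passed in.
    return exp_lpd_point.shape[1]

other = np.zeros((5, 4))
print(stacking_buggy(other))  # 3: reads the global, ignores the argument
print(stacking_fixed(other))  # 4
```

No error is raised in the buggy version, which is why the notebook appeared to work.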
Hi @cpieringer, thanks for the catch! Do you want to submit the fix?
Hi @fehiepsi, I will try.
Hi @fehiepsi, I was fixing the issue and noticed that somebody has already fixed it. However, I have a question about line 70 in def stack: logp = jax.nn.logsumexp(exp_lpd_point + log_w, axis=1)
Here the function receives exp(lpd_point), not lpd_point. Is it correct to apply logsumexp to it again inside the function stack?
I think we need to use jax.nn.logsumexp(lpd_point + log_w, axis=1). We can also reuse exp_lpd_point with
w = numpyro.deterministic("w", jnp.exp(log_w))
logp = jnp.log((w * exp_lpd_point).sum(1))
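but using logsumexp is more numerically stable, because it works in log space and avoids exp(lpd_point) underflowing to 0 for very negative log densities.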
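To make the stability point concrete, here is a minimal sketch with made-up values: the toy lpd_point array and the fixed softmax weights are assumptions, not taken from the notebook. It compares the logsumexp form, the exp-and-sum form, and the original line-70 form.

```python
import jax
import jax.numpy as jnp

# Toy pointwise log predictive densities: N=2 points, K=3 models.
# The second row is very negative, as for a poorly fitted data point.
lpd_point = jnp.array([[-1.0, -2.0, -0.5],
                       [-800.0, -810.0, -805.0]])
exp_lpd_point = jnp.exp(lpd_point)  # second row underflows to 0.0

# Assumed stacking weights: a fixed softmax instead of a model draw.
log_w = jax.nn.log_softmax(jnp.array([[0.1, 0.2, 0.3],
                                      [0.1, 0.2, 0.3]]), axis=1)
w = jnp.exp(log_w)

# Correct, stable form: log(sum_k w_k * exp(lpd_k)) computed in log space.
logp_stable = jax.nn.logsumexp(lpd_point + log_w, axis=1)

# Mathematically equivalent form reusing exp_lpd_point; it gives -inf
# for the second point because exp(-800) is 0 in floating point.
logp_naive = jnp.log((w * exp_lpd_point).sum(1))

# Original line-70 form: logsumexp applied to already-exponentiated values,
# which computes a different quantity altogether.
logp_buggy = jax.nn.logsumexp(exp_lpd_point + log_w, axis=1)

print(logp_stable, logp_naive, logp_buggy)
```

On the well-behaved first point the two correct forms agree, while the line-70 form does not; on the second point only the logsumexp form stays finite.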
OK. I can change the arguments to mcmc.run so that lpd_point is passed into the function directly. Do you agree?
mcmc.run(
    jax.random.PRNGKey(17),
    X=X_stacking_train,
    d_discrete=4,
    X_test=X_stacking_test,
    exp_lpd_point=lpd_point,
    tau_mu=1.0,
    tau_sigma=0.5,
    test=True,
)
Yeah, sounds good to me. Could you also rename the argument of the stack function from exp_lpd_point to lpd_point, as you suggested in the PR? The MCMC call would then be mcmc.run(..., lpd_point=lpd_point).
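Renaming the parameter and the keyword at the call site together also means a stale call fails loudly. A minimal sketch, assuming that mcmc.run forwards its keyword arguments to the model; run_model and this toy stack are hypothetical stand-ins, not numpyro code:

```python
def stack(X, lpd_point, *, test):
    # Renamed parameter: the body refers to the argument it receives.
    return len(lpd_point[0])  # number of candidate models K

def run_model(model, **kwargs):
    # Hypothetical runner that forwards keyword arguments to the model.
    return model(**kwargs)

lpd = [[-1.0, -2.0, -0.5]]
print(run_model(stack, X=None, lpd_point=lpd, test=True))  # 3

# A call site still using the old keyword now raises TypeError instead of
# silently picking up a global variable.
try:
    run_model(stack, X=None, exp_lpd_point=lpd, test=True)
except TypeError as err:
    print("TypeError:", err)
```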
This is fixed in #1480. Thanks @cpieringer!