
Typo in the example Bayesian Hierarchical Stacking

cpieringer opened this issue

Hi, I noticed a typo in the Bayesian Hierarchical Stacking notebook. In the function stacking (cell 16), line 34, the code reads:

K = lpd_point.shape[1] # number of candidate models

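However, the function's argument is exp_lpd_point, not lpd_point. The example still runs because a Python function can read variables from the enclosing (global) scope, so lpd_point is silently picked up from the notebook's global namespace. The same typo appears in the same function on line 62: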

logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)
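A minimal, self-contained illustration of why the notebook still ran despite the typo. The function and variable names here are hypothetical stand-ins, not the notebook's code: a name used inside a function body that is neither a parameter nor a local variable is looked up in the global scope, so no NameError is raised.

```python
# Hypothetical notebook-level global, standing in for the notebook's lpd_point.
lpd_point = [[-1.0, -2.0], [-0.5, -3.0]]  # 2 candidate models


def stacking_buggy(exp_lpd_point):
    # Typo: the body uses lpd_point, which is not a parameter.
    # Python resolves it from the global scope, so the argument is ignored.
    return len(lpd_point[0])  # "number of candidate models"


def stacking_fixed(lpd_point):
    # The body uses the parameter, so the result depends only on the input.
    return len(lpd_point[0])


other = [[0.0, 0.0, 0.0]]  # 3 candidate models
print(stacking_buggy(other))  # 2: silently read from the global
print(stacking_fixed(other))  # 3
```

Because the global and the intended argument hold the same array in the notebook, the bug produces correct results there and only surfaces when the function is called with different data.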

Best

cpieringer, May 16 '22

Hi @cpieringer, thanks for the catch! Do you want to submit the fix?

fehiepsi, May 17 '22

Hi @fehiepsi, I will try.

cpieringer, May 19 '22

Hi @fehiepsi, while working on the fix I noticed that somebody had already fixed it. However, I have a question about line 70 in def stack: logp = jax.nn.logsumexp(exp_lpd_point + log_w, axis=1)

In this case, the function receives exp(lpd_point), i.e. values that are already exponentiated. Since logsumexp exponentiates its input again, does it make sense to apply it to these values inside the function stack?

cpieringer, Sep 14 '22

I think we need to use jax.nn.logsumexp(lpd_point + log_w, axis=1). We can also reuse exp_lpd_point with

w = numpyro.deterministic("w", jnp.exp(log_w))
logp = jnp.log((w * exp_lpd_point).sum(1))

but using logsumexp is more numerically stable.
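A small sketch comparing the two formulations on toy numbers (the arrays and helper names below are illustrative assumptions, not the notebook's data). For moderate log densities both give the same log mixture density; for very negative ones, exp underflows to 0 and the naive version returns -inf, while logsumexp stays finite. It also shows that feeding already-exponentiated values into logsumexp gives a different, incorrect result.

```python
import jax
import jax.numpy as jnp


def logp_logsumexp(lpd_point, log_w):
    # log sum_k w_k * p_k, computed entirely in log space (max-shift inside logsumexp)
    return jax.nn.logsumexp(lpd_point + log_w, axis=1)


def logp_naive(exp_lpd_point, log_w):
    # Same quantity, but exponentiates first; can underflow to log(0) = -inf
    w = jnp.exp(log_w)
    return jnp.log((w * exp_lpd_point).sum(1))


log_w = jnp.log(jnp.array([[0.3, 0.7]]))  # toy stacking weights for 2 models

moderate = jnp.array([[-1.0, -2.0]])
print(logp_logsumexp(moderate, log_w))         # both agree here
print(logp_naive(jnp.exp(moderate), log_w))

# Wrong usage: logsumexp applied to exp(lpd_point) instead of lpd_point
print(logp_logsumexp(jnp.exp(moderate), log_w))  # differs from the two above

tiny = jnp.array([[-1000.0, -1001.0]])
print(logp_logsumexp(tiny, log_w))             # finite, roughly -1000.58
print(logp_naive(jnp.exp(tiny), log_w))        # -inf: exp(-1000) underflows to 0
```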

fehiepsi, Sep 14 '22

OK. I can change the arguments of mcmc.run so that lpd_point is passed to the function directly. Do you agree?

mcmc.run(
    jax.random.PRNGKey(17),
    X=X_stacking_train,
    d_discrete=4,
    X_test=X_stacking_test,
    exp_lpd_point=lpd_point,
    tau_mu=1.0,
    tau_sigma=0.5,
    test=True,
)

cpieringer, Sep 14 '22

Yeah, sounds good to me. Could you also rename the parameter in the stack function's signature from exp_lpd_point to lpd_point, as you suggested in the PR? The mcmc.run call would then be mcmc.run(..., lpd_point=lpd_point).
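With the rename, the likelihood part of the stacking function only needs log-scale inputs. Below is a hypothetical standalone helper sketching that piece, not the notebook's actual code: it assumes per-point unconstrained scores `f` of shape (N, K-1), with the last model's score pinned to 0, turned into log weights via `jax.nn.log_softmax`. Inside a NumPyro model the returned total would typically be added with `numpyro.factor`.

```python
import jax
import jax.numpy as jnp


def stacking_log_lik(f, lpd_point):
    """Hypothetical helper: total log-likelihood of a stacking mixture.

    f:         (N, K-1) unconstrained scores per data point (assumption;
               the K-th model's score is fixed at 0 for identifiability)
    lpd_point: (N, K) pointwise log predictive densities, on the log scale
    """
    N, K = lpd_point.shape  # K = number of candidate models, read from the argument
    scores = jnp.concatenate([f, jnp.zeros((N, 1))], axis=1)  # (N, K)
    log_w = jax.nn.log_softmax(scores, axis=1)  # per-point log stacking weights
    logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)  # (N,) log mixture density
    return logp.sum()


# Usage: equal weights (all scores 0) over 2 models and 2 data points
f = jnp.zeros((2, 1))
lpd = jnp.log(jnp.array([[0.2, 0.4], [0.1, 0.3]]))
print(stacking_log_lik(f, lpd))  # log(0.3) + log(0.2) = log(0.06)
```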

fehiepsi, Sep 15 '22

This is fixed in #1480. Thanks @cpieringer!

fehiepsi, Sep 16 '22