numpyro
Using numpyro primitives in jax's control flows fori_loop, while_loop, and cond
Currently, we have scan implemented in contrib.control_flow. After #878, it should be doable to use other control flows in HMC/NUTS. The implementation should mimic what we currently have in scan (but will be much simpler) and can be addressed step-by-step:
- ~[ ] Support while_loop~
  - We can assume there is no primitive in cond_fn, because any primitive in cond_fn can be moved to body_fn.
  - We can assume body_fn only contains observed nodes, because while_loop is not meant to be used to collect values.
  - Because body_fn only contains observed nodes, there is no need to worry about enumeration.
  - Because body_fn does not have latent variables, there is no need to worry about the substitute/condition handlers.
- ~[ ] Support fori_loop~
  - Any fori_loop can be rewritten as a while_loop, so supporting this should be straightforward.
- Support cond?
  - [x] Support cond
  - [ ] Support cond with discrete latent variable inside body_fn
  - [ ] Support cond with discrete latent variable outside body_fn
  - [ ] Support cond under a plate of conditions
- ~[ ] Support nested while_loop~
- [ ] Support nested scan with no latent variables (but we can collect values)
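As a rough illustration of the fori_loop item above, here is a minimal sketch of how a fori_loop can be expressed with jax.lax.while_loop. The helper name fori_loop_via_while and the toy body add_index are hypothetical and only serve the example. Note that cond_fn only compares the counter, in line with the assumption that it contains no primitives.

```python
from jax import lax


def fori_loop_via_while(lower, upper, body_fun, init_val):
    """Hypothetical helper: emulate lax.fori_loop using lax.while_loop."""

    def cond_fn(state):
        # The condition only inspects the loop counter.
        i, _ = state
        return i < upper

    def body_fn(state):
        # Apply the user body, then advance the counter.
        i, val = state
        return i + 1, body_fun(i, val)

    _, final_val = lax.while_loop(cond_fn, body_fn, (lower, init_val))
    return final_val


def add_index(i, total):
    # Toy loop body: accumulate the loop index.
    return total + i


expected = lax.fori_loop(0, 5, add_index, 0)
actual = fori_loop_via_while(0, 5, add_index, 0)
```

Both calls return only the final carried value, which is why neither loop is suited to collecting per-step results.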
I think this is an interesting piece of work and not as complicated as scan. It is a good chance to get familiar with JAX's control flow and many NumPyro handlers. If anyone is interested in addressing this issue, please let me know if you have any questions.
Hey @fehiepsi
I would be interested in working on some of these as an excuse to learn more about JAX and NumPyro. Do you have an example of what a model using while_loop might look like? I feel like that would help give me a bit more context to understand what's required.
In particular when you say things like "We can assume body_fn only contains observed nodes" I'm interested to know the reasoning behind a statement like that. Is that a design choice or does it actually not make sense to collect latent variables like this?
Thanks in advance for any pointers!
Hi @tcbegley, back then, the motivation for this issue was to support compiling some Stan models, which mainly use for loops. But it turns out that #878 seems to be enough for that purpose. For body functions that have latent variables, it is better to use scan because we want to collect them (while_loop and fori_loop are not used to collect values). Back then, I forgot that to make predictions, we also need to collect values. So it does not make sense to use while_loop or fori_loop for unobserved/observed random variables. Sorry for this!
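To make the "collecting values" point concrete, here is a small sketch in plain JAX (not NumPyro's scan wrapper) that contrasts the two. The array xs and the body functions are made up for illustration: fori_loop hands back only the final carry, while scan also stacks a per-step output.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(1.0, 5.0)  # illustrative data: [1., 2., 3., 4.]


def fori_body(i, total):
    # Only the running total is carried; nothing is stored per step.
    return total + xs[i]


final_only = lax.fori_loop(0, xs.shape[0], fori_body, jnp.zeros(()))


def scan_body(total, x):
    new_total = total + x
    # The second return value is collected (stacked) across all steps.
    return new_total, new_total


final_carry, running_totals = lax.scan(scan_body, jnp.zeros(()), xs)
```

Here running_totals holds every intermediate value, which is the behaviour needed for latent variables and for predictions.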
I think only supporting cond makes sense for now. But I'm not sure if there is a need for it. Do you think it is an interesting problem to pursue?
Thanks @fehiepsi, that makes a lot of sense. And thanks for the link to the paper, looks very interesting.
I think adding support for cond would be personally interesting for me, but mainly as an excuse to understand different handlers and NumPyro internals better. I agree that the need for it from the user's perspective is maybe a bit less clear.
My interest in this issue was largely driven by the "good first issue" tag. Really, I'm just keen to make some contributions where I can. If there are other things I could more usefully direct some attention to, then let me know!
Hi @tcbegley and @fehiepsi,
It's true that we have implemented fori_loop with a scan (see dppllib.py).
But we would still be interested in support for cond (and also while_loop) for the compilation of Stan to NumPyro!
For example, we are not able to support in NumPyro the program of Figure 10 of our paper:
    parameters {
      real cluster;
      real theta;
    }
    model {
      real mu;
      cluster ~ normal(0, 1);
      if (cluster > 0) mu = 20;
      else mu = 0;
      theta ~ normal(mu, 1);
    }
    guide parameters {
      real m1;
      real m2;
      real<lower=0> s1;
      real<lower=0> s2;
    }
    guide {
      cluster ~ normal(0, 1);
      if (cluster > 0) theta ~ normal(m1, s1);
      else theta ~ normal(m2, s2);
    }
The if statement in the model can be compiled directly to JAX's cond since there are no sampling statements in it, but the if in the guide would require a cond in NumPyro.
In the Stan models that we have seen, if is used more often than while. So if I had to pick one operator, I would prefer to have cond. But for completeness, it would be good to also have while_loop.