Why does `common_rewrite` call `cond_fun` before `step < max_steps`?
`common_rewrite` is used to rewrite `cond_fun` to be lax-compatible. It rewrites `new_cond_fun` to first compute `cond_fun` and then check `out & (step < max_steps)`. Would it not be more efficient to return False immediately if `step >= max_steps`? (Of course, there's a chance that jit resolves this at compile time.) Is the reason to meet user expectations that certain impure computations might still occur at every step?
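For concreteness, here is a minimal sketch of the ordering being asked about. It is not the library's actual `common_rewrite`; `bounded_while_loop` and its carry layout are hypothetical names made up for illustration, and only `lax.while_loop` is real JAX API.

```python
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    # Hypothetical stand-in for the rewrite under discussion, not the real code.
    def new_cond_fun(carry):
        step, val = carry
        out = cond_fun(val)              # user condition is evaluated on every check
        return out & (step < max_steps)  # then combined with the step bound

    def new_body_fun(carry):
        step, val = carry
        return step + 1, body_fun(val)

    init = (jnp.asarray(0, dtype=jnp.int32), init_val)
    _, final_val = lax.while_loop(new_cond_fun, new_body_fun, init)
    return final_val


# The step bound stops the loop first: five increments, starting from 0.
capped = bounded_while_loop(lambda x: x < 100, lambda x: x + 1, jnp.asarray(0), max_steps=5)
# The user condition stops the loop first.
uncapped = bounded_while_loop(lambda x: x < 100, lambda x: x + 1, jnp.asarray(0), max_steps=1000)
```

In this form `cond_fun` runs on every check, including the last one where the step bound alone would already end the loop.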
Haha, you're definitely getting into the weeds here.
So, as `step` is a traced value, I'm guessing what you're getting at is that we could wrap `cond_fun` in a `lax.cond` to avoid computing it in the `step >= max_steps` case?
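Something like the following, as a rough sketch only. `bounded_while_loop_lazy` is a hypothetical name, not anything in the codebase, and the sketch assumes `cond_fun` returns a scalar boolean array so that both `lax.cond` branches have matching output types.

```python
import jax.numpy as jnp
from jax import lax


def bounded_while_loop_lazy(cond_fun, body_fun, init_val, max_steps):
    # Hypothetical variant: check the step bound first, evaluate cond_fun only if needed.
    def new_cond_fun(carry):
        step, val = carry
        return lax.cond(
            step < max_steps,
            cond_fun,                                  # taken while within the step budget
            lambda _: jnp.asarray(False, dtype=bool),  # budget exhausted: skip cond_fun
            val,
        )

    def new_body_fun(carry):
        step, val = carry
        return step + 1, body_fun(val)

    init = (jnp.asarray(0, dtype=jnp.int32), init_val)
    _, final_val = lax.while_loop(new_cond_fun, new_body_fun, init)
    return final_val


capped_lazy = bounded_while_loop_lazy(lambda x: x < 100, lambda x: x + 1, jnp.asarray(0), max_steps=5)
uncapped_lazy = bounded_while_loop_lazy(lambda x: x < 100, lambda x: x + 1, jnp.asarray(0), max_steps=1000)
```

This gives the same results as the `out & (step < max_steps)` form for a pure `cond_fun`. Whether it is actually faster is not established; it trades one evaluation of `cond_fun` for an extra conditional inside the loop condition.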
I don't have immediate strong opinions on this. I've generally learnt to distrust the XLA compiler around control flow, which often inhibits its ability to perform optimisations across that boundary (it seems to perform very little, if any, code motion). In addition, a user-provided `cond_fun` is often pretty cheap. (And in fact I think JAX internally also performs a bunch of rewrites here when lowering, so what you see as the interface to `lax.while_loop` is anyway not exactly what is passed to XLA.)
On the other hand, I can see that theoretically we could skip the final evaluation of `cond_fun`.