Implement an option for lax.while_loop to specify the maximum number of iterations, to allow reverse differentiation
Several forms of loops in JAX support reverse-mode AD: `scan`, and `fori_loop` with constant bounds, which is syntactic sugar for `scan`. I think it would be useful to have similar syntactic sugar for bounded loops, by adding a `max_iterations` parameter to `while_loop`.
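To make the request concrete, here is a minimal sketch of how such a bounded loop might be emulated today on top of `lax.scan` and `lax.cond`. The helper name `bounded_while_loop` and its signature are hypothetical, chosen only for illustration; they are not an existing JAX API, and the proposed `max_iterations` parameter need not be implemented this way.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_iterations):
    """Hypothetical helper: while_loop semantics, capped at max_iterations steps."""

    def step(carry, _):
        # Apply body_fun only while cond_fun holds; otherwise pass the carry through.
        new_carry = lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
        return new_carry, None

    # scan has a static trip count, which is what makes reverse-mode AD possible.
    final, _ = lax.scan(step, init_val, xs=None, length=max_iterations)
    return final


def sqrt_newton(a):
    # Newton iteration for sqrt(a), stopping once the residual is small.
    a = jnp.asarray(a, dtype=jnp.float32)
    cond = lambda x: jnp.abs(x * x - a) > 1e-6
    body = lambda x: 0.5 * (x + a / x)
    return bounded_while_loop(cond, body, a, max_iterations=20)


value = sqrt_newton(2.0)
grad = jax.grad(sqrt_newton)(2.0)  # reverse-mode gradient through the bounded loop
```

Note that this sketch always performs `max_iterations` scan steps, even after the condition becomes false, so the bound trades some wasted iterations for differentiability.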