
Implement an option for lax.while_loop to specify the maximum number of iterations, to allow reverse differentiation

Open gnecula opened this issue 4 years ago • 3 comments

Several forms of loops in JAX support reverse-mode AD: scan, and fori_loop with constant bounds (which is syntactic sugar for scan). I think it would be useful to add similar syntactic sugar for bounded loops, in the form of a max_iterations parameter on while_loop.
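To make the idea concrete, here is a minimal sketch of how a bounded while loop could be expressed on top of scan. The helper name bounded_while_loop and its signature are assumptions made for illustration, not an existing JAX API; the proposal above would instead expose this through a max_iterations argument on while_loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_iterations):
    """Hypothetical helper (not part of JAX): while loop with an iteration cap.

    The loop is lowered to lax.scan with a fixed trip count, so it can be
    differentiated in reverse mode.
    """
    def step(val, _):
        # Apply body_fun only while cond_fun holds; otherwise pass the
        # carry through unchanged for the remaining iterations.
        new_val = lax.cond(cond_fun(val), body_fun, lambda v: v, val)
        return new_val, None

    final_val, _ = lax.scan(step, init_val, xs=None, length=max_iterations)
    return final_val


def halve_until_small(x):
    # Halve x while it exceeds 1.0, running at most 10 iterations.
    return bounded_while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x, 10)


value = halve_until_small(jnp.float32(8.0))
grad = jax.grad(halve_until_small)(jnp.float32(8.0))
```

One trade-off of this sketch is that scan always executes max_iterations steps, even after the condition has become false, so the bound should be chosen as tightly as practical.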

gnecula, Mar 20 '20 11:03