Is it possible to have early stopping in `lax.scan`?
I am considering a loop in which each operation is expensive. The worst-case number of iterations can be bounded by a fixed N, but there is a stopping condition, and on average the loop terminates after K iterations with K << N.
With lax.while_loop it is relatively easy to implement this efficiently but unfortunately one loses reverse mode differentiation. Would it be possible to have a version of lax.scan that supports such an early stopping?
You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.
The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.
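To make the `lax.cond`-inside-`lax.scan` idea concrete, here is a minimal sketch. The Newton iteration for a square root stands in for the real expensive step, and the names `newton_sqrt`, `max_steps`, and `tol` are made up for this example; they are not part of any JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_sqrt(a, max_steps=50, tol=1e-6):
    """Approximate sqrt(a) with at most max_steps Newton steps (hypothetical example)."""

    def body(carry, _):
        x, done = carry

        def run(x, a):
            # The "expensive" step: one Newton update for x**2 = a.
            x_new = 0.5 * (x + a / x)
            return x_new, jnp.abs(x_new - x) < tol

        def skip(x, a):
            # Identity branch: taken on every step after convergence.
            return x, jnp.array(True)

        x, done = lax.cond(done, skip, run, x, a)
        return (x, done), None

    init = (a, jnp.array(False))
    (x, done), _ = lax.scan(body, init, None, length=max_steps)
    return x, done


value, converged = newton_sqrt(jnp.float32(2.0))
grad = jax.grad(lambda a: newton_sqrt(a)[0])(jnp.float32(2.0))
```

The scan still executes all `max_steps` iterations and allocates memory for all of them, but once `done` is set each remaining iteration only takes the cheap identity branch, and reverse-mode differentiation works.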
This came up in one of our chat rooms yesterday, and it was also observed that because of how we currently implement vmap-of-cond (namely we always turn it into a select), using cond for this might lead to inefficient batching.
I think adding an early stopping version of lax.scan is a good idea. It's not trivial though. Notes to self: we'd probably need to add a jaxpr for the early-stopping function, and also add 'start' and 'stop' arguments to the scan for the linearized case.
As @shoyer observed, we'd always pay the memory cost.
Until we add this (or figure out a better plan), you might be able to work around the issue using jax.custom_vjp around your loop.
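One possible shape for such a workaround, as a sketch only: run the forward pass with `lax.while_loop` and supply the backward pass by hand through `jax.custom_vjp`. This assumes you can write the derivative yourself. In the toy example below, the loop converges to a solution of x**2 = a, so implicit differentiation gives dx/da = 1 / (2x). The function names and the tolerance are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_vjp
def sqrt_newton(a):
    # Forward pass: stops as soon as two successive iterates agree.
    def cond_fun(state):
        x, x_prev = state
        return jnp.abs(x - x_prev) >= 1e-6

    def body_fun(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    # x_prev starts away from x so that the loop runs at least once.
    x, _ = lax.while_loop(cond_fun, body_fun, (a, a + 1.0))
    return x


def sqrt_newton_fwd(a):
    x = sqrt_newton(a)
    return x, x  # keep the converged value as the residual


def sqrt_newton_bwd(x, g):
    # Hand-written rule, assuming the loop converged: x**2 = a => dx/da = 1 / (2x).
    return (g / (2.0 * x),)


sqrt_newton.defvjp(sqrt_newton_fwd, sqrt_newton_bwd)

value = sqrt_newton(jnp.float32(2.0))
grad = jax.grad(sqrt_newton)(jnp.float32(2.0))
```

Here the loop really does stop early and no per-step residuals are stored, at the cost of having to derive the backward rule yourself.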
> You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.
> The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.
Hi, I am trying to break out of lax.scan early, since lax.while_loop takes longer to compile and runs slower. However, I did not understand exactly how to use lax.cond to achieve that. Could you elaborate, e.g. by giving a small code example?
Hey @mattjj! Late to the party, but I am facing the same limitations. Is there any plan to implement this?
Sorry for bothering, but IMO this feature would have a wide range of uses, e.g. stopping auto-regressive LLM generation when EOS is encountered. I would really appreciate it if this could be implemented.
+1