Is it possible to have early stopping in `lax.scan`?

Open salayatana66 opened this issue 4 years ago • 7 comments

I am considering a loop of operations, each of which is expensive. The worst-case number of iterations can be bounded by a fixed N, but there is a stopping condition, and on average the loop terminates after K iterations with K << N.

With lax.while_loop it is relatively easy to implement this efficiently, but unfortunately one loses reverse-mode differentiation. Would it be possible to have a version of lax.scan that supports such early stopping?

salayatana66 avatar Feb 05 '21 10:02 salayatana66

You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.

The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.
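A minimal sketch of this suggestion, under stated assumptions: `expensive_step`, `should_stop`, `MAX_STEPS` and `run` are hypothetical stand-ins (they are not from this thread or from JAX), and the "expensive" work is just a halving. The carry holds a `done` flag; once it is set, `lax.cond` picks the identity branch for the remaining iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 50  # hypothetical worst-case bound N


def expensive_step(x):
    # Stand-in for the costly per-iteration work (assumption: halve the value).
    return 0.5 * x


def should_stop(x):
    # Hypothetical stopping condition.
    return x < 1e-3


def scan_body(carry, _):
    x, done = carry
    # When `done` is True, take the identity branch instead of doing the work.
    x_next = lax.cond(done, lambda v: v, expensive_step, x)
    done_next = jnp.logical_or(done, should_stop(x_next))
    return (x_next, done_next), None


def run(x0):
    init = (x0, jnp.array(False))
    (x_final, _), _ = lax.scan(scan_body, init, xs=None, length=MAX_STEPS)
    return x_final


x0 = jnp.float32(1.0)
print(run(x0))            # value after the stopping condition first holds
print(jax.grad(run)(x0))  # reverse-mode differentiation works through scan + cond
```

The scan still runs all `MAX_STEPS` iterations; only the work inside each remaining iteration is skipped. Under `vmap`, the `cond` may be lowered to a select, in which case both branches are evaluated.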

shoyer avatar Feb 05 '21 21:02 shoyer

This came up in one of our chat rooms yesterday, and it was also observed that because of how we currently implement vmap-of-cond (namely we always turn it into a select), using cond for this might lead to inefficient batching.

I think adding an early stopping version of lax.scan is a good idea. It's not trivial though. Notes to self: we'd probably need to add a jaxpr for the early-stopping function, and also add 'start' and 'stop' arguments to the scan for the linearized case.

As @shoyer observed, we'd always pay the memory cost.

mattjj avatar Feb 05 '21 22:02 mattjj

Until we add this (or figure out a better plan), you might be able to work around the issue using jax.custom_vjp around your loop.
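One way this workaround might look, as a sketch rather than a general recipe: it assumes the derivative of the loop's result is known in closed form, so the backward pass never has to differentiate through `lax.while_loop`. The function `newton_sqrt`, its tolerance, and the helper names are hypothetical and chosen only for illustration. The loop solves x^2 = a by Newton iteration, and the backward rule uses dx/da = 1 / (2x).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_vjp
def newton_sqrt(a):
    # Forward pass: Newton iteration that stops as soon as it has converged.
    # Assumes a > 0.
    def cond_fun(state):
        x, prev = state
        return jnp.abs(x - prev) > 1e-6

    def body_fun(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    x, _ = lax.while_loop(cond_fun, body_fun, (a, jnp.zeros_like(a)))
    return x


def newton_sqrt_fwd(a):
    x = newton_sqrt(a)
    return x, (x,)  # save the result as the residual for the backward pass


def newton_sqrt_bwd(residuals, g):
    (x,) = residuals
    # Implicit differentiation of x**2 = a gives dx/da = 1 / (2 * x).
    return (g / (2.0 * x),)


newton_sqrt.defvjp(newton_sqrt_fwd, newton_sqrt_bwd)

a = jnp.float32(2.0)
print(newton_sqrt(a))            # close to sqrt(2)
print(jax.grad(newton_sqrt)(a))  # close to 1 / (2 * sqrt(2))
```

This only helps when a backward rule can be written by hand (analytically or via implicit differentiation); it does not make an arbitrary `while_loop` body reverse-mode differentiable.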

mattjj avatar Feb 06 '21 04:02 mattjj

> You could do this with lax.cond inside lax.scan -- just apply the identity function if the loop should terminate early.
>
> The downside is that you will pay the price of memory allocation for all N steps. This is a limitation of XLA's memory model: all memory allocation must use statically known shapes.

Hi, I am trying to implement an early break within lax.scan, since lax.while_loop takes longer to compile and runs slower. However, I did not understand exactly how to use lax.cond to achieve that. Could you elaborate, for example with a small code sample?

dumanah avatar Jul 16 '22 17:07 dumanah

Hey @mattjj! Late to the party, but I am facing the same limitations. Is there any plan to implement this?

epignatelli avatar Jun 09 '24 16:06 epignatelli

Sorry for bothering, but IMO this feature would have a wide range of uses, e.g. stopping auto-regressive LLM generation when an EOS token is encountered. I would really appreciate it if this could be implemented.

dest1n1s avatar Nov 22 '24 06:11 dest1n1s

+1

renos avatar Dec 10 '25 22:12 renos