
Scan-like and batched vmap-like transformations

Open · michael-0brien opened this issue 5 months ago · 1 comment

Hello! I recently opened an issue in JAX about improving usability by expanding its set of function transformations: https://github.com/jax-ml/jax/issues/30528. I think many of us can relate to frequently having to fall back on jax.lax, and to how this changes the way we write code. The gist of the issue is to 1) add a function transformation for jax.lax.scan and 2) add a transformation for batched vmapping (jax.bmap?).
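To make the second proposal concrete, one plausible reading of "batched vmapping" is: vmap a function over fixed-size chunks of the leading axis and evaluate the chunks one after another, trading some parallelism for bounded peak memory. The sketch below is a hypothetical `bmap` (not an existing JAX or Equinox API) built only from `jax.vmap`, `jax.lax.map` and reshapes. It assumes every leaf of `xs` shares a leading dimension that is divisible by `batch_size`. Recent JAX releases also accept a `batch_size` argument to `jax.lax.map` that behaves similarly; check the documentation for your version.

```python
import jax
import jax.numpy as jnp


def bmap(f, xs, batch_size):
    """Hypothetical batched vmap: apply jax.vmap(f) to `batch_size` rows at a
    time, looping over the chunks sequentially with jax.lax.map.

    Assumes every leaf of `xs` has the same leading dimension, divisible by
    `batch_size` (a ragged final chunk is not padded).
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if n % batch_size != 0:
        raise ValueError(f"leading dimension {n} is not divisible by {batch_size}")
    num_chunks = n // batch_size

    # (n, ...) -> (num_chunks, batch_size, ...)
    chunked = jax.tree_util.tree_map(
        lambda x: x.reshape((num_chunks, batch_size) + x.shape[1:]), xs
    )
    # Vectorise within each chunk; lax.map scans over the chunks one at a time.
    out = jax.lax.map(jax.vmap(f), chunked)
    # (num_chunks, batch_size, ...) -> (n, ...)
    return jax.tree_util.tree_map(lambda y: y.reshape((n,) + y.shape[2:]), out)


# Usage: the result should match a plain vmap over all 6 rows.
xs = jnp.arange(12.0).reshape(6, 2)
ys = bmap(lambda row: jnp.sum(row**2), xs, batch_size=3)
```

Only one chunk of 3 rows is vectorised at a time, so intermediate memory scales with `batch_size` rather than with the full leading dimension.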

I wanted to put this on the equinox team's radar since this is a community that thinks a lot about JAX usability! If this is addressed, it would be wonderful to eventually have an eqx.filter_scan and eqx.filter_bmap.

michael-0brien, Jul 26 '25 17:07

Hey there! So I think this is one where we'd probably follow JAX's lead. If they add a jax.bmap, we'll add an eqx.filter_bmap :)

On the idea of a filter_scan (or filter_{cond,while_loop,...}): this has come up a few times. It's something we could consider adding, although so far the demand for it doesn't seem to have been huge, and it's easy enough to implement with eqx.partition and eqx.combine.

patrick-kidger, Jul 27 '25 14:07