Scan-like and batched vmap-like transformations
Hello! I recently opened an issue in JAX about improving usability by expanding its set of function transformations: https://github.com/jax-ml/jax/issues/30528. I think many of us can relate to regularly having to fall back on jax.lax, and to how that changes the way we think about writing code. The gist of the issue is to 1) add a function transformation for jax.lax.scan and 2) add a transformation for batched vmapping (jax.bmap?).
I wanted to put this on the Equinox team's radar, since this is a community that thinks a lot about JAX usability! If this is addressed, it would be wonderful to eventually have an eqx.filter_scan and eqx.filter_bmap.
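For concreteness, here is a rough sketch of what the two proposed transformations could look like from the user's side. Neither `scan_transform` nor `bmap` exists in JAX today; both names and signatures are hypothetical and are built here only from `jax.lax.scan`, `jax.lax.map` and `jax.vmap`. The `bmap` sketch assumes a single array input whose leading axis is divisible by `batch_size`. (Recent JAX releases also accept a `batch_size` argument on `jax.lax.map`, which does similar chunking.)

```python
import functools

import jax
import jax.numpy as jnp


def scan_transform(step):
    """Hypothetical decorator: lift step(carry, x) -> (carry, y) over a sequence."""

    @functools.wraps(step)
    def scanned(init, xs):
        # Delegates to the existing primitive; only the calling convention changes.
        return jax.lax.scan(step, init, xs)

    return scanned


def bmap(f, batch_size):
    """Hypothetical batched vmap: vmap within each chunk, loop over chunks."""

    def mapped(xs):
        n = xs.shape[0]
        # Assumption for this sketch: the leading axis splits evenly into chunks.
        assert n % batch_size == 0, "leading axis must be divisible by batch_size"
        chunks = xs.reshape(n // batch_size, batch_size, *xs.shape[1:])
        # Sequential loop over chunks, vectorised inside each chunk.
        out = jax.lax.map(jax.vmap(f), chunks)
        return out.reshape(n, *out.shape[2:])

    return mapped


@scan_transform
def cumulative_sum(carry, x):
    carry = carry + x
    return carry, carry


total, partials = cumulative_sum(jnp.float32(0.0), jnp.arange(5, dtype=jnp.float32))
squares = bmap(lambda x: x**2, batch_size=4)(jnp.arange(12.0))
```

The point of the sketch is only the calling style: the step function is written once and decorated, rather than being passed to `jax.lax.scan` at every call site.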
Hey there! So I think this is one where we'd probably follow JAX's lead. If they add a jax.bmap, we'll add an eqx.filter_bmap :)
On the idea of a filter_scan (or filter_{cond,while_loop,...}): this has come up a few times. It's something we could consider adding, although so far there hasn't been much demand for it, and it's easy enough to implement yourself with eqx.partition and eqx.combine.