Tracker: decomposing scan (aka "Five Loop")
With jaxprs now supporting effect types, we can express side effects like the State monad, where references can be read from and written to (i.e. mutation). We can use state to implement a simpler `scan` control-flow primitive on top of a `for` primitive that supports reads and writes.
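To make the idea concrete, here is a minimal sketch of `scan` written as a loop that reads inputs from a buffer and writes each step's output into a preallocated buffer. It does not use the new `for` primitive, whose API is experimental. Instead it emulates the "refs" with `jax.lax.fori_loop`, threading the buffers explicitly. The name `scan_via_for` is hypothetical.

```python
import jax
import jax.numpy as jnp

def scan_via_for(f, init, xs):
    """Hypothetical sketch: scan(f, init, xs) as a loop over read/write buffers."""
    n = xs.shape[0]
    # Trace f abstractly once to learn the shape/dtype of each step's output y.
    _, y_struct = jax.eval_shape(f, init, xs[0])
    ys = jnp.zeros((n,) + y_struct.shape, y_struct.dtype)  # output "ref"

    def body(i, state):
        carry, ys = state
        x = xs[i]                  # "read" x_ref[i]
        carry, y = f(carry, x)     # the user's scan body
        ys = ys.at[i].set(y)       # "write" ys_ref[i] = y
        return carry, ys

    return jax.lax.fori_loop(0, n, body, (init, ys))
```

The result should match `jax.lax.scan(f, init, xs)` for a body `f` that returns a single array `y` per step.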
This issue will track the implementation progress:
- [x] `get`/`swap`/`addupdate` primitives
  - [x] `impl` rules
  - [x] `abstract_eval` rules
  - [x] `jvp` rules
  - [x] `transpose` rules
  - [x] `vmap` rules
  - [x]
- [x] Discharging state
  - [x] Basic implementation
  - [x] Handling higher-order primitives
- [ ] `for` primitive
  - [x] `impl` rule
  - [x] `abstract_eval` rule
  - [x] MLIR lowering
  - [x] `jvp` rule
  - [ ] `partial_eval`
    - [x] basic implementation
    - [ ] optimizations
      - [x] loop invariant
        - [x] ~make loop index a ref and~ use read/writes to determine which values are loop-invariant
      - [ ] residual passthrough
      - [ ] rematerializing loop-dependent values
  - [x] `transpose`
  - [x] `vmap` rule
  - [x] `partial_eval_custom` rule
- [x] Miscellaneous
  - [x] Handling closed-over refs
  - [x] Nested `for` loops
  - [x] Unrolling
- [x] Reimplement `scan` in terms of `for`
- [x]
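As a rough illustration of what "discharging state" means for the `get`/`swap`/`addupdate` primitives, the sketch below gives each one a pure semantics: it takes the current contents of a ref (a plain array) and returns a result together with the new contents. This is a hypothetical model, not JAX's internal implementation. It assumes `swap` writes a new value and returns the old one, as its name suggests.

```python
import jax.numpy as jnp

# Hypothetical "discharged" model: each op maps (ref contents, ...) -> (result, new contents).
def get(ref_val, idx):
    return ref_val[idx], ref_val

def swap(ref_val, idx, new):
    # Write `new` at idx and return the value that was there before.
    return ref_val[idx], ref_val.at[idx].set(new)

def addupdate(ref_val, idx, delta):
    # Accumulate into the ref; there is no meaningful result.
    return None, ref_val.at[idx].add(delta)

def program(ref_val):
    # A small stateful program with the state threaded through explicitly.
    x, ref_val = get(ref_val, 0)
    old, ref_val = swap(ref_val, 1, x * 10)
    _, ref_val = addupdate(ref_val, 2, old)
    return ref_val
```

Because every step is now a pure function of the ref's contents, the whole program can be traced, differentiated, or `jit`-compiled like any other JAX function.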
The "raw" version of `for` can be found here. Next steps involve porting that code to JAX core and adding tests.
Hi, for a non-PL/compiler person, can you say a bit about what this means?
Are there any user-facing implications, such as being able to express impure loops at the cost of sequential execution, or being able to write loops that can be compiled to parallel functional primitives based on their effect type? (Kinda like what Dex does.)
> Hi, for a non-PL/compiler person, can you say a bit about what this means?
We're exploring implementing a new control-flow primitive (`for`) that generalizes `scan` and is more flexible and expressive. We're also exploring a "state" side effect in JAX and its ramifications. Right now, we're confining the side effect to the body of this new `for` primitive. For some examples of its usage, you can see the control flow tests.
> Are there any user-facing implications such as being able to express impure loops?
Yes, potentially. The `for` implementation is still in an exploratory stage, but it is a generalization of `scan` because we can express more than just `scan`-like patterns (see the `cumsum` example in the tests).
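In the spirit of that `cumsum` example (this is our own sketch, not the code from the tests), a cumulative sum can be written the way a `for` body would express it: iteration `i` reads the previously written output `out[i - 1]` and writes `out[i]`. The refs are again emulated by threading an array through `jax.lax.fori_loop`, and `cumsum_for` is a hypothetical name. It assumes a non-empty 1-D input.

```python
import jax
import jax.numpy as jnp

def cumsum_for(xs):
    """Hypothetical sketch: cumulative sum via reads/writes to an output buffer."""
    out = jnp.zeros_like(xs)
    out = out.at[0].set(xs[0])            # out[0] = xs[0]

    def body(i, out):
        # Read the previous *output* and the current input, then write out[i].
        return out.at[i].set(out[i - 1] + xs[i])

    return jax.lax.fori_loop(1, xs.shape[0], body, out)
```

A `scan` would instead have to carry the running total explicitly; here the loop simply reads back what it already wrote.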
XLA doesn't have a (general) "parallel for" like Dex does, but you're exactly right about the direction we hope to go. If we imagine we're lowering to a Dex backend, we could potentially parallelize the `for` loop.
Hope this is helpful!
Yes, that's very informative, thanks! Sounds like some cool facilities are in the pipeline.
Can we expect some performance improvements by re-implementing `scan` in terms of `for`?
> Can we expect some performance improvements by re-implementing `scan` in terms of `for`?
Probably not. However, some patterns cannot be expressed efficiently as a `scan` but can be as a `for_loop`. In these cases, you can use `for_loop` and potentially see a speedup.
@sharadmv Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX user-accessible?
This is a bit of a side question, so I can open a discussion if you'd like to move it over there.
> This is a bit of a side discussion question, so I can make a discussion if you'd like to move it over there.
Yes, I'd prefer this discussion to happen elsewhere, though I can give a brief answer here.
> Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX user-accessible?
Yes and no. Yes, it is possible to register lowering rules for your custom primitives via `jax.interpreters.mlir.register_lowering`, and no, because it is an internal API: there are no stability guarantees, and your code could break at any time.
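For illustration only, a minimal sketch of registering a lowering rule for a custom primitive might look like the following. The primitive `my_double` is hypothetical, and `mlir.lower_fun` is used to build the lowering by tracing an ordinary JAX function. Since these are internal APIs, import paths and behavior may differ between JAX versions (the `try`/`except` covers the two `Primitive` locations we know of).

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:  # newer JAX versions expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX versions
    from jax.core import Primitive

# Hypothetical custom primitive that doubles its input.
double_p = Primitive("my_double")

# Eager (non-jit) evaluation rule.
double_p.def_impl(lambda x: x * 2)

# Abstract evaluation: the output has the same shape/dtype as the input.
double_p.def_abstract_eval(lambda x_aval: x_aval)

# Lowering rule used under jit, built by tracing an ordinary JAX function.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False)
)

def my_double(x):
    return double_p.bind(x)
```

Under `jax.jit`, `my_double` is lowered through the registered rule. Called eagerly, it uses the `def_impl` rule instead.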