
Tracker: decomposing scan (aka "Five Loop")

Open sharadmv opened this issue 2 years ago • 7 comments

With jaxprs now supporting effect types, we can express side effects such as those of the State monad, where references can be read from and written to (i.e., mutation). We can use state to implement scan more simply, via a for primitive that supports reads and writes.
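For readers less familiar with effects, here is a minimal pure-Python model of the idea. It is not JAX's implementation: Ref, for_loop, and scan below are hypothetical names chosen for illustration. Ref is a mutable cell whose get, swap, and addupdate methods loosely mirror the primitives in the checklist. for_loop runs a body that receives the loop index and a list of refs it may read and write. Scan is then recovered as a for loop that reads the carry, overwrites it, and writes one output slot per iteration.

```python
class Ref:
    """Hypothetical mutable cell modelling a state reference (not JAX's API)."""

    def __init__(self, value):
        self._value = value

    def get(self, idx=None):
        # Read the whole value, or a single element when idx is given.
        return self._value if idx is None else self._value[idx]

    def swap(self, new, idx=None):
        # Write `new` and return the value it replaced.
        old = self.get(idx)
        if idx is None:
            self._value = new
        else:
            self._value[idx] = new
        return old

    def addupdate(self, delta, idx=None):
        # Accumulate `delta` into the ref; nothing is returned.
        if idx is None:
            self._value = self._value + delta
        else:
            self._value[idx] = self._value[idx] + delta


def for_loop(nsteps, body, init_state):
    """Hypothetical for: run body(i, refs) for each i, return final contents."""
    refs = [Ref(v) for v in init_state]
    for i in range(nsteps):
        body(i, refs)
    return [r.get() for r in refs]


def scan(f, init, xs):
    """scan built on for_loop: a carry ref plus a preallocated output ref."""
    n = len(xs)

    def body(i, refs):
        carry_ref, ys_ref = refs
        new_carry, y = f(carry_ref.get(), xs[i])
        carry_ref.swap(new_carry)  # overwrite the carry
        ys_ref.swap(y, idx=i)      # write the i-th output slot

    carry, ys = for_loop(n, body, [init, [None] * n])
    return carry, ys
```

For example, scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4]) returns (10, [1, 3, 6, 10]). The point of the model is that the for body is free to write anywhere in its refs, while scan only ever writes the carry and slot i.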

This issue will track the implementation progress:

  • [x] get/swap/addupdate primitives
    • [x] impl rules
    • [x] abstract_eval rules
    • [x] jvp rules
    • [x] transpose rules
    • [x] vmap rules
  • [x] Discharging state
    • [x] Basic implementation
    • [x] Handling higher-order primitives
  • [ ] for primitive
    • [x] impl rule
    • [x] abstract_eval rule
    • [x] MLIR lowering
    • [x] jvp rule
    • [ ] partial_eval
      • [x] basic implementation
      • [ ] optimizations
        • [x] loop invariant
        • [x] use reads/writes to determine which values are loop-invariant (an earlier idea, making the loop index a ref, was dropped)
        • [ ] residual passthrough
        • [ ] rematerializing loop-dependent values
    • [x] transpose
    • [x] vmap rule
    • [x] partial_eval_custom rule
    • [x] Miscellaneous
      • [x] Handling closed-over refs
      • [x] Nested for loops
      • [x] Unrolling
    • [x] Reimplement scan in terms of for
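Of the items above, "discharging state" is the step that turns a stateful loop back into a pure one. The following is a rough pure-Python illustration under our own assumptions; Ref, discharge, pure_for_loop, and body are hypothetical names, not JAX internals. Running the body against fresh refs initialised from explicit values, then reading the refs back, yields a pure function from old values to new values, which an ordinary carry-threading loop can iterate.

```python
import copy


class Ref:
    """Hypothetical mutable cell standing in for a state reference."""

    def __init__(self, value):
        self.value = value


def discharge(body):
    """Turn a stateful body(i, refs) into a pure fn (i, vals) -> new vals."""
    def pure_body(i, vals):
        # Copy the inputs so the caller's values are never mutated.
        refs = [Ref(copy.deepcopy(v)) for v in vals]
        body(i, refs)
        return tuple(r.value for r in refs)
    return pure_body


def pure_for_loop(nsteps, pure_body, init_vals):
    """Thread state explicitly through the loop, as a carry would be."""
    vals = tuple(init_vals)
    for i in range(nsteps):
        vals = pure_body(i, vals)
    return vals


def body(i, refs):
    """Stateful body: add i to a running total and record it in a log."""
    total, log = refs
    total.value = total.value + i
    log.value[i] = total.value
```

For example, pure_for_loop(4, discharge(body), (0, [0, 0, 0, 0])) returns (6, [0, 1, 3, 6]) and leaves its input tuple unchanged. Once the body is pure in this sense, the loop has the same shape as an explicit-carry loop such as jax.lax.fori_loop.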

The "raw" version of for can be found here. Next steps involve porting that code to JAX core and adding tests.

sharadmv, Jun 04 '22 01:06

Hi, for a non-PL/compiler person, can you say a bit about what this means?

Are there any user-facing implications, such as being able to express impure loops at the cost of sequential execution, or being able to write loops that can be compiled to parallel functional primitives based on their effect type? (Kinda like what Dex does.)

AriMKatz, Jul 05 '22 19:07

Hi, for a non-PL/compiler person, can you say a bit about what this means?

We're exploring a new control-flow primitive (for) that generalizes scan and is more flexible and expressive. We're also exploring a "state" side effect in JAX and its ramifications. Right now, we're confining the side effect to the body of this new for primitive. For some examples of its usage, see the control flow test.

Are there any user-facing implications, such as being able to express impure loops?

Yes, potentially. The for implementation is still at an exploratory stage, but it generalizes scan because it can express more than just scan-like patterns (see the cumsum example in the tests).
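The cumsum example mentioned above is written with the experimental for primitive, whose API is not reproduced here. As a hedged approximation using only public, stable JAX APIs, the same pattern (a loop index, a running total, and one write per iteration into an output buffer) can be written with jax.lax.fori_loop, threading the buffer through the carry and updating it with .at[i].set. This explicit-carry form is roughly what discharging the state effect produces. The name cumsum_loop is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumsum_loop(x):
    """Cumulative sum as an index loop that writes one output slot per step."""

    def body(i, carry):
        total, out = carry
        total = total + x[i]
        out = out.at[i].set(total)  # functional stand-in for a ref write
        return total, out

    init = (jnp.zeros((), x.dtype), jnp.zeros_like(x))
    _, out = lax.fori_loop(0, x.shape[0], body, init)
    return out
```

With x = jnp.arange(1.0, 6.0), cumsum_loop(x) matches jnp.cumsum(x), i.e. [1., 3., 6., 10., 15.], and it also works under jax.jit.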

XLA doesn't have a (general) "parallel for" like Dex does, but you're exactly right about the direction we hope to go. If we imagine we're lowering to a Dex backend, we could potentially parallelize the for loop.

Hope this is helpful!

sharadmv, Jul 05 '22 23:07

Yes that's very informative, thanks! Sounds like some cool facilities are in the pipeline.

AriMKatz, Jul 06 '22 18:07

Can we expect some performance improvements by re-implementing scan in terms of for?

LoicRaillon, Jul 07 '22 07:07

Can we expect some performance improvements by re-implementing scan in terms of for?

Probably not. However, some patterns that cannot be expressed efficiently as a scan can be expressed efficiently as a for_loop. In those cases, you can use for_loop and potentially see a speedup.

sharadmv, Aug 29 '22 21:08

@sharadmv Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX user-accessible?

This is a bit of a side-discussion question, so I can open a discussion if you'd like to move it over there.

femtomc, Aug 31 '22 01:08

This is a bit of a side-discussion question, so I can open a discussion if you'd like to move it over there.

Yes, I'd prefer this discussion happen elsewhere, though I can give a brief answer here.

Is it possible to define new MLIR lowering rules for primitives outside of the core JAX repo? Or: is that part of JAX user-accessible?

Yes and no. Yes, you can register lowering rules for your custom primitives via jax.interpreters.mlir.register_lowering. No, because it is an internal API: there are no stability guarantees, and your code could break at any time.

sharadmv, Aug 31 '22 03:08