generate lax.slice instead of lax.gather for x[<int>] or x[:<int>]
Before, on CPU:
After, on CPU:
Following up on #11866.
I don't know whether these small compile-time differences matter, but hey, roofshots! The jaxpr pretty-printing win alone is enough for me.
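Here is a minimal sketch of why the rewrite in the title is valid. It is not the code in this PR. When a static Python int indexes the leading axis, `x[i]` equals a size-1 `lax.slice` followed by `lax.squeeze`, and `x[:n]` equals a single `lax.slice`. The helpers `take_int_via_slice` and `take_prefix_via_slice` are hypothetical names, and the sketch assumes only the leading axis is indexed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def take_int_via_slice(x, i):
    """Hypothetical helper: x[i] for a static Python int i, built from lax.slice."""
    n = x.shape[0]
    if i < 0:
        i += n  # normalize negative indices the way NumPy does
    if not 0 <= i < n:
        raise IndexError(f"index {i} out of bounds for axis 0 with size {n}")
    start = (i,) + (0,) * (x.ndim - 1)
    limit = (i + 1,) + x.shape[1:]
    # Take a size-1 window along axis 0, then drop that axis.
    return lax.squeeze(lax.slice(x, start, limit), (0,))


def take_prefix_via_slice(x, stop):
    """Hypothetical helper: x[:stop] for a static Python int stop, built from lax.slice."""
    # slice.indices resolves negative stops and clips to [0, len], like NumPy.
    _, stop, _ = slice(None, stop).indices(x.shape[0])
    start = (0,) * x.ndim
    limit = (stop,) + x.shape[1:]
    return lax.slice(x, start, limit)


x = jnp.arange(12.0).reshape(4, 3)

# Same values as ordinary NumPy-style indexing.
print(bool(jnp.array_equal(take_int_via_slice(x, -1), x[-1])))   # True
print(bool(jnp.array_equal(take_prefix_via_slice(x, 2), x[:2])))  # True

# The traced program uses only slice/squeeze primitives, no gather.
print(jax.make_jaxpr(lambda a: take_int_via_slice(a, 1))(x))
```

A `slice` equation records only its start and limit indices. A `gather` equation also carries dimension numbers and slice sizes, which is plausibly where the shorter pretty-printed jaxprs come from.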