jax
simplify slicing jaxprs a little
Before:

After:

No more lt, add, select, or convert_element_type!
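To see which primitives a slicing expression produces on your own JAX version, you can trace it with `jax.make_jaxpr` and list the equation primitives. This is a minimal sketch for inspection only. The helper `primitive_names` and the two example functions are hypothetical names introduced here, and the exact jaxpr you get depends on the JAX version installed, so no output is shown.

```python
import jax
import jax.numpy as jnp

def take_static(x):
    # Static integer index: known at trace time, so no runtime bounds math is needed.
    return x[1]

def take_negative(x):
    # Negative static index: can also be normalized in Python at trace time.
    return x[-1]

def primitive_names(fn, *args):
    # Hypothetical helper: trace fn and return the primitive name of each jaxpr equation.
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]

x = jnp.arange(5.0)

# Print the full jaxprs, then just the primitive names, for comparison.
print(jax.make_jaxpr(take_static)(x))
print(jax.make_jaxpr(take_negative)(x))
print(primitive_names(take_static, x))
print(primitive_names(take_negative, x))
```

Checking for `lt`, `add`, `select_n`, or `convert_element_type` in those lists is a quick way to see whether index normalization is being emitted into the jaxpr.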
This might actually improve compilation time a bit: on CPU, the tiny benchmark added in this change went from ~15.5 ms to ~13 ms on my machine.
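A rough way to measure this kind of effect yourself is to time tracing, lowering, and XLA compilation of a slice-heavy function. This is a minimal sketch, not the benchmark added in this change: `slice_heavy` and `compile_seconds` are hypothetical names, the workload is an assumption chosen only because it emits many slicing equations, and a single timing is noisy, so repeat it before drawing conclusions.

```python
import time

import jax
import jax.numpy as jnp

def slice_heavy(x):
    # Hypothetical workload: 50 static indexing ops, each adding equations to the jaxpr.
    return sum(x[i] for i in range(50))

def compile_seconds(fn, *args):
    # Wrap fn in a fresh lambda so jit's trace cache does not reuse an earlier trace.
    jitted = jax.jit(lambda *a: fn(*a))
    start = time.perf_counter()
    lowered = jitted.lower(*args)  # trace to a jaxpr and lower to StableHLO
    lowered.compile()              # run XLA compilation
    return time.perf_counter() - start

x = jnp.arange(100.0)
print(f"trace + lower + compile: {compile_seconds(slice_heavy, x) * 1e3:.1f} ms")
```

Timing `lower` and `compile` together matters here, since simpler jaxprs should reduce tracing and lowering work as well as the work XLA does.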