jax
simplify slicing jaxprs a little
Before:
After:
No more `lt`, `add`, `select`, or `convert_element_type`!
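To check which primitives a slicing expression produces on a given install, something like the following sketch may help. It is not taken from this change: `primitive_names`, `static_slice`, and `static_index` are hypothetical names chosen for illustration, and the primitives printed will depend on the JAX version.

```python
import jax
import jax.numpy as jnp


def primitive_names(fn, *args):
    """Return the primitive names in the top-level jaxpr traced from fn(*args)."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    # Only top-level equations are listed; nested jaxprs are not expanded.
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.arange(10.0)

# Hypothetical indexing patterns (assumptions, not the cases from this change).
static_slice = lambda a: a[2:5]
static_index = lambda a: a[3]

# Print the full jaxpr, then just the primitive names for each pattern.
print(jax.make_jaxpr(static_slice)(x))
print(primitive_names(static_slice, x))
print(primitive_names(static_index, x))
```

Comparing the printed names before and after an upgrade is one way to see whether the extra comparison and select ops are still present.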
This might actually improve compilation time a bit; on CPU, the tiny benchmark added went from ~15.5ms to 13ms on my machine.
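For a rough local comparison, a sketch along these lines times tracing, lowering, and compilation of a small slicing function. It is not the benchmark added here: `time_compile` and `static_slice` are hypothetical, and the numbers will vary with machine, backend, and JAX version.

```python
import time

import jax
import jax.numpy as jnp


def time_compile(fn, *args):
    """Trace, lower, and compile fn for the given args; return (compiled, seconds)."""
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()
    elapsed = time.perf_counter() - start
    return compiled, elapsed


x = jnp.arange(10.0)

# Hypothetical function under test (an assumption, not the benchmarked case).
static_slice = lambda a: a[2:5]

compiled, seconds = time_compile(static_slice, x)
print(f"trace + lower + compile: {seconds * 1e3:.2f} ms")

# The compiled executable can be called with arguments of the same shape and dtype.
print(compiled(x))
```

A single measurement like this is noisy, so repeating it a few times and taking the minimum is probably more informative.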