jax
Strided indexing turns into gather
Description
Indexing syntax such as a[1::100], where a is a JAX array inside a jit-compiled function, appears to lower to a gather operation rather than a strided slice, at least on TPU. This is inefficient, and jax.lax.slice already supports strides.
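A minimal sketch for checking what the strided index lowers to, alongside the explicit jax.lax.slice call the report points to as the efficient alternative. It assumes a 1-D array of length 1000. Whether the first jaxpr shows a gather or a slice may depend on the installed JAX version, so compare the two printed jaxprs rather than assuming a fixed result.

```python
import jax
import jax.numpy as jnp
from jax import lax

def strided_index(a):
    # NumPy-style strided indexing; per the report, this may lower to a gather.
    return a[1::100]

def strided_slice(a):
    # Explicit strided slice: start at 1, stop at the end, step 100.
    return lax.slice(a, start_indices=(1,), limit_indices=(a.shape[0],), strides=(100,))

x = jnp.arange(1000.0)

# Print the primitives each version traces to (look for "gather" vs "slice").
print(jax.make_jaxpr(strided_index)(x))
print(jax.make_jaxpr(strided_slice)(x))

# Both should select indices 1, 101, ..., 901.
same = bool(jnp.array_equal(jax.jit(strided_index)(x), jax.jit(strided_slice)(x)))
print(same)
```

The lax.slice form pins the lowering to a single slice primitive regardless of how jnp indexing is rewritten internally.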
System info (python version, jaxlib version, accelerator, etc.)
N/A
Thanks for the report! Lowering contiguous slices to lax.slice was an optimization we made a while ago, and at the time we scoped the problem to contiguous slices for simplicity. Adding strided slices to the logic would require adding support for them in this utility: https://github.com/google/jax/blob/1cb69716fe5c456206c2bfd494e5a4f8fd81bbbf/jax/_src/numpy/lax_numpy.py#L5490
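As a rough illustration of what strided support in such a rewrite might involve, the helper below converts a tuple of static Python slices with positive steps into a single strided lax.slice call and returns None for anything else. The name static_slices_to_lax_slice is hypothetical and this is a sketch under those assumptions, not the actual utility linked above.

```python
import jax.numpy as jnp
from jax import lax

def static_slices_to_lax_slice(a, idx):
    """Hypothetical helper: rewrite static, positive-step slices as lax.slice.

    Returns None when the index falls outside this simple case.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    if len(idx) > a.ndim or not all(isinstance(s, slice) for s in idx):
        return None
    # Only static Python ints (or None) can be resolved at trace time.
    for s in idx:
        if not all(v is None or isinstance(v, int) for v in (s.start, s.stop, s.step)):
            return None
    # Pad unindexed trailing dimensions with full slices.
    idx = idx + (slice(None),) * (a.ndim - len(idx))
    starts, limits, strides = [], [], []
    for s, dim in zip(idx, a.shape):
        start, stop, step = s.indices(dim)
        if step <= 0:
            return None  # negative steps would also need lax.rev
        starts.append(start)
        limits.append(max(stop, start))  # an empty slice keeps limit >= start
        strides.append(step)
    return lax.slice(a, tuple(starts), tuple(limits), tuple(strides))

x = jnp.arange(24).reshape(4, 6)
print(static_slices_to_lax_slice(x, (slice(1, None, 2), slice(None, None, 3))))
```

Negative steps are deliberately left out, since lax.slice only accepts positive strides; handling them would need an extra reversal.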
It seems like something that should work if normal slices work. It is surprising that syntax supported by both Python and NumPy, which maps directly onto a feature JAX already supports (strided lax.slice), doesn't compile to efficient code.
Agreed, that's why I marked this as a performance-related enhancement.