
Strided indexing turns into gather

Open • jewillco opened this issue 1 year ago • 3 comments

Description

Indexing syntax such as a[1::100], where a is a JAX array inside a jit-compiled function, appears to lower to a gather operation rather than a strided slice, at least on TPU. This is inefficient, since jax.lax.slice already supports strides.
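A quick way to check which primitive an indexing expression produces is to print its jaxpr with jax.make_jaxpr. The sketch below compares the NumPy-style expression against an explicit jax.lax.slice call with a stride, which is one possible workaround. Whether a[1::100] actually shows up as a gather depends on the JAX version you run. The function names here are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.arange(1000, dtype=jnp.float32)

def via_indexing(a):
    # NumPy-style strided indexing; per this report it may lower to a gather
    return a[1::100]

def via_lax_slice(a):
    # Explicit strided slice: start=1, limit=len(a), stride=100
    return lax.slice(a, (1,), (a.shape[0],), (100,))

# Inspect which primitives each version produces
print(jax.make_jaxpr(via_indexing)(a))
print(jax.make_jaxpr(via_lax_slice)(a))

# Both versions produce the same values under jit
assert jnp.array_equal(jax.jit(via_indexing)(a), jax.jit(via_lax_slice)(a))
```

The explicit lax.slice version should show a single slice primitive with strides=(100,) in its jaxpr.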

System info (python version, jaxlib version, accelerator, etc.)

N/A

jewillco • May 09 '24 20:05

Thanks for the report! Lowering contiguous slices to lax.slice was an optimization we made a while ago, and at the time we limited it to contiguous slices for simplicity. Supporting strided slices would require extending this utility: https://github.com/google/jax/blob/1cb69716fe5c456206c2bfd494e5a4f8fd81bbbf/jax/_src/numpy/lax_numpy.py#L5490
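As a rough illustration of what extending that logic to strides involves, here is a hypothetical helper. It is not the JAX implementation. It converts a tuple of static, positive-step Python slices into start, limit, and stride tuples for a single lax.slice call, and falls back to regular indexing for anything else. It assumes slice bounds are plain Python ints known at trace time. Negative steps are left out because lax.slice does not accept negative strides; handling them would need something like lax.rev.

```python
import jax
import jax.numpy as jnp
from jax import lax

def take_via_strided_slice(arr, idx):
    """Hypothetical helper: rewrite static, positive-step slices into one lax.slice."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    # Only handle pure slice indexing that does not exceed the array's rank
    if len(idx) > arr.ndim or not all(isinstance(s, slice) for s in idx):
        return arr[idx]  # fall back to regular (possibly gather-based) indexing
    starts, limits, strides = [], [], []
    for dim, size in enumerate(arr.shape):
        s = idx[dim] if dim < len(idx) else slice(None)
        # slice.indices normalizes None/negative bounds against the static size
        start, stop, step = s.indices(size)
        if step <= 0:
            return arr[idx]  # negative steps would need lax.rev; not handled here
        stop = max(stop, start)  # empty selection: lax.slice needs start <= limit
        starts.append(start)
        limits.append(stop)
        strides.append(step)
    return lax.slice(arr, tuple(starts), tuple(limits), tuple(strides))

x = jnp.arange(1000)
print(jax.make_jaxpr(lambda a: take_via_strided_slice(a, slice(1, None, 100)))(x))
```

Because slice.indices already clamps bounds and resolves negative starts and stops, the number of elements lax.slice returns in each dimension matches len(range(start, stop, step)), the same count Python/NumPy indexing would produce.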

jakevdp • May 09 '24 20:05

It seems like this should work if contiguous slices already do. It is surprising that syntax supported by both Python/NumPy indexing, which maps directly onto a feature JAX already provides, doesn't lower efficiently.

jewillco • May 09 '24 20:05

Agreed, that's why I marked this as a performance-related enhancement.

jakevdp • May 09 '24 20:05