generate dynamic_slice rather than slice for simple indexing/slicing
Follow-up to #11867; fixes #12198.
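As a quick illustration (a sketch based on my reading of the change, not code from the PR): inspecting the jaxpr for a simple static index should now show dynamic_slice, which takes its start indices as operands, rather than slice, which bakes them into the program and therefore yields a distinct XLA computation for every index value.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0)

# Simple static indexing; after this change the jaxpr is expected to contain
# dynamic_slice (plus a squeeze) rather than slice with the start index baked in.
print(jax.make_jaxpr(lambda a: a[3])(x))
```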
Want to add a benchmark to api_benchmarks.py so we don't regress this performance issue in the future?
Hmm, since this is an XLA caching thing, I don't think our benchmark framework as currently written can catch it, because after the first execution subsequent executions will be fast.
We might be able to follow the pattern in these benchmarks to foil caching.
I'm not following – does state.__iter__ do some sort of XLA-level cache clear?
I don't see any easy way to benchmark this in our benchmark framework, because (if my understanding is correct) it would require clearing XLA's internal cache on each iteration, and I don't know how to do that short of restarting the runtime, which cannot easily be done here. I'm going to proceed with landing this change without the benchmark.
We're blocked by two jax.experimental.jet issues:
- https://github.com/google/jax/issues/12263
- a missing jet rule for dynamic_update_slice_p (see the reproduction sketch below)
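For context, a minimal reproduction sketch of the second blocker (hypothetical: the call shape follows the jet docstring, and the exact error may differ): Taylor-mode differentiation of anything that hits dynamic_update_slice_p should fail until a rule is added.

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import jet

def f(x):
  # Exercises dynamic_update_slice_p directly.
  return lax.dynamic_update_slice(x, jnp.zeros((1,)), (0,))

x = jnp.arange(4.0)
# One primal and one series of Taylor coefficients per primal; expected to
# fail with a missing-rule error for dynamic_update_slice until a jet rule
# is written.
jet.jet(f, (x,), ((jnp.ones_like(x),),))
```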
Added a second commit with benchmarks to prevent regression.
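For reference, a minimal sketch of the shape these benchmarks take, following the google_benchmark pattern used in JAX's benchmark files; the names match the results below, but the body here is illustrative rather than the exact committed code.

```python
import google_benchmark as benchmark
import jax.numpy as jnp


@benchmark.register
def bench_repeated_static_indexing(state):
  x = jnp.arange(500.0)
  while state:
    # Index with many distinct static indices; before this change each index
    # value produced a distinct XLA program, so repeated op-by-op indexing
    # was slow on the first pass through the loop.
    for i in range(100):
      x[i].block_until_ready()


if __name__ == "__main__":
  benchmark.main()
```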
Before:

| name | cpu/op | time/op | allocs/op | peak-mem(Bytes)/op |
| --- | --- | --- | --- | --- |
| bench_repeated_static_indexing | 11.1s ±37% | 11.2s ±38% | 0.00 ±NaN% | 0.00 ±NaN% |
| bench_repeated_static_slicing | 9.68s ±39% | 9.72s ±40% | 0.00 ±NaN% | 0.00 ±NaN% |

After:

| name | cpu/op | time/op | allocs/op | peak-mem(Bytes)/op |
| --- | --- | --- | --- | --- |
| bench_repeated_static_indexing | 206ms ± 3% | 208ms ± 3% | 0.00 ±NaN% | 0.00 ±NaN% |
| bench_repeated_static_slicing | 152ms ± 2% | 153ms ± 3% | 0.00 ±NaN% | 0.00 ±NaN% |