
generate dynamic_slice rather than slice for simple indexing/slicing

Open jakevdp opened this issue 3 years ago • 6 comments

Followup to #11867, fixes #12198

jakevdp avatar Sep 02 '22 22:09 jakevdp

Want to add a benchmark to api_benchmarks.py so we don't regress this performance issue in the future?

mattjj avatar Sep 02 '22 22:09 mattjj

Hmm, since this is an XLA caching issue, I don't think our benchmark framework as currently written can catch it: after the first execution, subsequent executions will be fast.
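To make the caching point concrete, here is a minimal sketch (not taken from the PR itself). It assumes the underlying cost is that `lax.slice` bakes its start index into the operation's static parameters, so each distinct index yields a distinct program, while `lax.dynamic_slice` takes the start index as a runtime operand. The array `x` and the names `static_jaxprs` and `dynamic_jaxprs` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

# Static slice: the start/limit indices are parameters of the primitive,
# so each index value traces to a different jaxpr.
static_jaxprs = {
    str(jax.make_jaxpr(lambda a: lax.slice(a, (i,), (i + 1,)))(x))
    for i in range(3)
}

# Dynamic slice: the start index is an ordinary traced argument,
# so every index value traces to the same jaxpr.
dynamic_jaxprs = {
    str(jax.make_jaxpr(lambda a, s: lax.dynamic_slice(a, (s,), (1,)))(x, i))
    for i in range(3)
}

print(len(static_jaxprs), "distinct static-slice programs")
print(len(dynamic_jaxprs), "distinct dynamic-slice programs")
```

Under that assumption, a loop over many static indices pays a compilation per index the first time through, whereas the dynamic form can reuse a single compiled program.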

jakevdp avatar Sep 02 '22 22:09 jakevdp

We might be able to follow the pattern in these benchmarks to foil caching.

mattjj avatar Sep 02 '22 23:09 mattjj

I'm not following – does state.__iter__ do some sort of XLA-level cache clear?

jakevdp avatar Sep 06 '22 18:09 jakevdp

I don't see any easy way to benchmark this in our benchmark framework, because (if my understanding is correct) it would require clearing XLA's internal cache on each iteration, and I don't know how to do that short of restarting the runtime, which cannot easily be done here. I'm going to proceed with landing this change without the benchmark.

jakevdp avatar Sep 07 '22 20:09 jakevdp

We're blocked by two jax.experimental.jet issues:

  • https://github.com/google/jax/issues/12263
  • missing jet rule for dynamic_update_slice_p

jakevdp avatar Sep 07 '22 21:09 jakevdp

Added a second commit with benchmarks to prevent regression.

Before:

name                            cpu/op
bench_repeated_static_indexing  11.1s ±37%
bench_repeated_static_slicing   9.68s ±39%

name                            time/op
bench_repeated_static_indexing  11.2s ±38%
bench_repeated_static_slicing   9.72s ±40%

name                            allocs/op
bench_repeated_static_indexing   0.00 ±NaN%
bench_repeated_static_slicing    0.00 ±NaN%

name                            peak-mem(Bytes)/op
bench_repeated_static_indexing   0.00 ±NaN%
bench_repeated_static_slicing    0.00 ±NaN%

After:

name                            cpu/op
bench_repeated_static_indexing  206ms ± 3%
bench_repeated_static_slicing   152ms ± 2%

name                            time/op
bench_repeated_static_indexing  208ms ± 3%
bench_repeated_static_slicing   153ms ± 3%

name                            allocs/op
bench_repeated_static_indexing   0.00 ±NaN%
bench_repeated_static_slicing    0.00 ±NaN%

name                            peak-mem(Bytes)/op
bench_repeated_static_indexing   0.00 ±NaN%
bench_repeated_static_slicing    0.00 ±NaN%
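For readers who want to reproduce the shape of this measurement, the following is a rough sketch of what such a benchmark could look like. It is not the code from the commit: the function names, the array size `n`, and the use of `time.perf_counter` in place of the project's benchmark framework are all assumptions made for illustration.

```python
import time
import jax.numpy as jnp

def time_repeated_static_indexing(n=50):
    """Time n eager lookups x[0], x[1], ..., each with a different static index."""
    x = jnp.arange(n)
    start = time.perf_counter()
    results = [x[i].block_until_ready() for i in range(n)]
    return time.perf_counter() - start, results

def time_repeated_static_slicing(n=50):
    """Time eager slices x[0:2], x[1:3], ..., each with different static bounds."""
    x = jnp.arange(n)
    start = time.perf_counter()
    results = [x[i:i + 2].block_until_ready() for i in range(n - 1)]
    return time.perf_counter() - start, results

index_time, index_results = time_repeated_static_indexing()
slice_time, slice_results = time_repeated_static_slicing()
print(f"indexing: {index_time:.3f}s, slicing: {slice_time:.3f}s")
```

Because each index is seen only once per fresh process, the first run of such a loop is where any per-index compilation cost would show up; rerunning in the same process would mostly hit caches.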

jakevdp avatar Nov 03 '22 18:11 jakevdp