[GPU] Optimize all-gathers on non-major dimension using a single transpose.
Summary of Changes

All-gathers can only run on the major-most physical dimension, concatenating the buffers from all ranks. When an all-gather on a logical dimension with index > 0 is requested, layout assignment assigns layouts to the input and the output such that the gathered dimension becomes the physically major-most one. This results in two transposes, one before and one after the all-gather.

Instead, such all-gathers can be rewritten to gather on dimension 0, concatenating the inputs in their original layout, followed by a single transpose that produces the requested output layout.
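The rewrite can be illustrated with a small NumPy sketch. The shapes and rank count here are illustrative only, not taken from the PR; the point is that stacking the per-rank shards along a new major-most dimension (which is what a physical all-gather produces) followed by one transpose yields the same result as the logically requested concatenation on a non-major dimension:

```python
import numpy as np

R = 4  # number of ranks (illustrative)
# Each rank holds a (3, 4) shard; values are tagged per rank for clarity.
shards = [np.arange(12).reshape(3, 4) + 100 * r for r in range(R)]

# Logically requested: all-gather on dim 1 -> shape (3, 4*R).
requested = np.concatenate(shards, axis=1)

# What the physical all-gather produces: rank buffers concatenated
# contiguously, i.e. stacked along a new major-most dimension (R, 3, 4).
gathered = np.stack(shards, axis=0)

# A single transpose (the trailing reshape is a free layout reinterpretation)
# recovers the requested result.
rewritten = gathered.transpose(1, 0, 2).reshape(3, R * 4)

assert np.array_equal(rewritten, requested)
```

Without the rewrite, the same result requires transposing the input so the gathered dimension is major-most, gathering, and transposing back, i.e. two transpose kernels instead of one.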
Justification

Performance improvement.

Kind of Contribution

Performance Improvement
Benchmark (for Performance Improvements)

Before:

Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=1 duration=42080ns
Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=2 duration=40704ns
Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=3 duration=40672ns

After:

Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=1 duration=35840ns
Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=2 duration=35616ns
Execution time, file=xla/xla/tools/benchmarks/hlo/u4_all_gather_1x8.hlo repeat=3 duration=35296ns
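Comparing the best of the three runs in each set, the durations above correspond to roughly a 13% reduction in execution time (a quick sanity-check computation, not part of the PR):

```python
# Durations in nanoseconds, copied from the benchmark output above.
before = [42080, 40704, 40672]
after = [35840, 35616, 35296]

# Best-of-three relative improvement.
improvement = 1 - min(after) / min(before)
print(f"{improvement:.1%}")  # prints "13.2%"
```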
Unit Tests: Yes.
Execution Tests: Yes.
Running benchmarks, will report back when I have results.
Some important benchmarks show regressions in the 10% range. The main difference I spotted is that this PR causes many extra wrapped_transpose kernels between all-gather-done ops and cuBLAS GEMMs.
I redesigned the pass such that it runs after layout assignment and does not prevent folding of transposes into other ops surrounding all-gathers. Please take another look.