Implement cross-host transfer optimizations in StreamExecutorGpuClient
📝 Summary of Changes
This PR adds three optimizations to the CrossHost{Send,Receive}Buffers methods in StreamExecutorGpuClient:
- Communicators are cached for re-use across transfers with AcquireGpuClique.
- Transfers of multiple arrays are aggregated with NCCL group calls using GpuCommunicator::GroupExecute.
- Usage / definition events and promises that signal the completion of cross-host transfers are fulfilled on StreamExecutorGpuClient::thread_pool() instead of blocking the XLA execute thread.
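
To illustrate the first optimization, here is a minimal, standalone sketch of communicator caching. It does not use XLA's AcquireGpuClique or any real NCCL communicator: `FakeCommunicator` and `CommunicatorCache` are hypothetical names invented for this example, and the cache key (the list of participating ranks) is an assumption about what identifies a reusable communicator. The sketch is not thread-safe.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

// Hypothetical stand-in for a communicator that is expensive to create.
struct FakeCommunicator {
  std::vector<int> ranks;
};

// Caches one communicator per set of participating ranks, so repeated
// transfers between the same participants reuse the existing communicator.
class CommunicatorCache {
 public:
  std::shared_ptr<FakeCommunicator> Acquire(const std::vector<int>& ranks) {
    assert(!ranks.empty());
    auto it = cache_.find(ranks);
    if (it != cache_.end()) {
      return it->second;  // Cache hit: no new communicator is created.
    }
    ++num_created_;
    auto comm = std::make_shared<FakeCommunicator>(FakeCommunicator{ranks});
    cache_.emplace(ranks, comm);
    return comm;
  }

  // Number of communicators created so far (cache misses).
  int num_created() const { return num_created_; }

 private:
  std::map<std::vector<int>, std::shared_ptr<FakeCommunicator>> cache_;
  int num_created_ = 0;
};
```

With this sketch, two calls to `Acquire({0, 1})` return the same object and only the first one pays the creation cost.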
This PR follows up on #33284 and is based on discussions with @emilyfertig, @mwhittaker, and @pschuh.
🎯 Justification
#33284 added CrossHost{Send,Receive}Buffers to the PjRt Client API to enable the optimizations implemented in this PR. The old implementation does not cache communicators, aggregate transfers, or asynchronously fulfill events / promises, leading to reduced performance.
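
The asynchronous fulfillment described above can be sketched in plain C++. This is an illustration under stated assumptions, not the code in this PR: `NotifyWhenDone` is a hypothetical helper, and `std::async` with a dedicated thread stands in for scheduling work on StreamExecutorGpuClient::thread_pool(). The point is only that the caller gets a future back immediately, while the wait for the transfer happens elsewhere.

```cpp
#include <cassert>
#include <chrono>
#include <future>

// Hypothetical helper: waits for `transfer_done` on a separate thread and
// only then makes the returned future ready. The calling thread does not
// block on the transfer.
inline std::future<bool> NotifyWhenDone(std::shared_future<void> transfer_done) {
  return std::async(std::launch::async, [transfer_done] {
    transfer_done.wait();  // Blocks the worker thread, not the caller.
    return true;           // Signals that the transfer has completed.
  });
}
```

A caller would pass in a future tied to the transfer, continue with other work, and check or wait on the returned future only when the result is needed.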
🚀 Kind of Contribution
⚡️ Performance Improvement
📊 Benchmark (for Performance Improvements)
Benchmarks were run on these three toy programs.
| Workload | Implementation | Time |
|---|---|---|
| example_basic_overlap | Unoptimized | 26.968 sec |
| example_basic_overlap | Optimized | 43.124 ms |
| example_overlap | Unoptimized | 369.541 sec |
| example_overlap | Optimized | 733.071 ms |
| example_pp | Unoptimized | 369.0924 sec |
| example_pp | Optimized | 2.433 sec |
🧪 Unit Tests: The unit tests added in #33284 continue to pass. The new implementation requires that cross-host transfers are always between a local device and a remote device; a test has been added to make sure that the implementation throws an InvalidArgument error if this is not the case.
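
The local/remote requirement can be read as "exactly one endpoint of each transfer is a local device". A minimal sketch of such a check follows. `IsLocalToRemoteTransfer` and the use of integer device ids are hypothetical choices for this example; the actual implementation reports an InvalidArgument error rather than returning a bool.

```cpp
#include <cassert>
#include <set>

// Hypothetical check: returns true only when exactly one of the two
// endpoints is in the set of local device ids. Local-to-local and
// remote-to-remote transfers are both rejected.
inline bool IsLocalToRemoteTransfer(const std::set<int>& local_device_ids,
                                    int src_device_id, int dst_device_id) {
  const bool src_is_local = local_device_ids.count(src_device_id) > 0;
  const bool dst_is_local = local_device_ids.count(dst_device_id) > 0;
  return src_is_local != dst_is_local;
}
```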
🧪 Execution Tests: When building XLA with the code in this PR along with a patch for pjrt_ifrt/pjrt_client.cc so that it uses the new cross-host transfers API, I verified that these 4 correctness checks pass. These IFRT changes will be included in a follow-up PR once other PjRt clients have implemented the new API.