Aggregate cross-host data transfers made by jax.device_put
This PR modifies `_batched_device_put_impl` to batch the cross-host data transfers of multiple arrays. This enables us to take advantage of recent optimizations inside XLA for cross-host data transfers, particularly on GPUs (XLA #33284, XLA #33910, XLA #34904). `CrossHostTransferTest.test_cross_host_transfer_batched` is lightly modified and enabled for the CUDA platform to test this change.
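For a sense of the call this affects, here is a minimal sketch. It assumes a multi-process GPU job where `jax.distributed.initialize()` has already run; the mesh and axis names are illustrative, and whether a given transfer is actually cross-host depends on where the source arrays and the destination sharding's devices live.

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Illustrative destination sharding spanning devices from every host in the job.
mesh = jax.make_mesh((jax.device_count(),), ("x",))
dst = NamedSharding(mesh, P("x"))

# Several arrays moved in one device_put call: with this change, the cross-host
# pieces of these transfers are issued as a single batched transfer instead of
# one transfer per array.
xs = [jnp.arange(16.0) * (i + 1) for i in range(4)]
ys = jax.device_put(xs, dst)
```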
Summary of Changes
Hello @rao-ashish, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request improves the performance of `jax.device_put` by aggregating and batching cross-host data transfers. By deferring individual transfer requests and executing them collectively, the system can leverage recent optimizations within XLA, particularly benefiting GPU-based computations. This change streamlines data movement across hosts, leading to more efficient execution of JAX programs in distributed environments.
Highlights
- **Batching Cross-Host Transfers**: The `jax.device_put` function now batches multiple cross-host data transfers, improving overall efficiency.
- **Leveraging XLA Optimizations**: This change takes advantage of recent XLA optimizations for batched cross-host data transfers, particularly benefiting GPU environments.
- **Deferred Transfer Mechanism**: A new dataclass, `_DeferredCrossHostTransferArg`, has been introduced to temporarily hold cross-host transfer requests, allowing them to be aggregated before execution (see the sketch after this list).
- **Updated `_batched_device_put_impl`**: The core `_batched_device_put_impl` function was modified to collect these deferred transfer objects and execute them in a single batched call.
- **CUDA Test Enabled**: The `CrossHostTransferTest.test_cross_host_transfer_batched` test case was updated and enabled for the CUDA platform to validate the new batched transfer behavior.
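Illustratively, the deferral amounts to recording each cross-host request and replaying the whole group in one call. The sketch below is a simplified stand-in, not the actual JAX internals: the class, its fields, and the `transfer_batch` callback are placeholders.

```python
import dataclasses
from typing import Any, Callable, Sequence

@dataclasses.dataclass
class DeferredCrossHostTransfer:
    """Simplified stand-in for a deferred cross-host transfer record."""
    index: int      # position of this array's result in the output batch
    array: Any      # source array resident on another process's devices
    sharding: Any   # destination sharding

def flush_deferred(deferred: Sequence[DeferredCrossHostTransfer],
                   results: list,
                   transfer_batch: Callable[[list, list], list]) -> None:
    """Issue every deferred cross-host transfer as one batched call.

    `transfer_batch` stands in for the batched runtime copy; outputs are written
    back into the positions the source arrays originally occupied.
    """
    outs = transfer_batch([d.array for d in deferred],
                          [d.sharding for d in deferred])
    for d, out in zip(deferred, outs):
        results[d.index] = out
```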
Thanks for reviewing this! I looked into the `pxla.shard_args` code, and I think we will need to handle cross-host transfers separately by directly calling `batched_copy_array_to_devices_with_sharding`.
`pxla.shard_args` for jax Arrays eventually calls `_shard_np_array` for each array, which breaks the batching of transfers. Moreover, that path calls into jaxlib's `batched_device_put` method with a batch consisting only of the shards local to a single array. `batched_device_put` ultimately calls into `DevicePutWithSharding` inside `py_values.cc`, which in the cross-host case also ends up calling `ifrt_client->CopyArrays(...)` one shard at a time. So the most feasible way to implement batching is to explicitly call `xc.batched_copy_array_to_devices_with_sharding`, as the original code was doing.
However, as you suggested, I've removed the new dataclass and added a field inside `_DeferredShardArgs` to designate cross-host transfers. I've also introduced a `_populate_batched_results` function to reduce code duplication between the two cases.
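Roughly, the new shape looks like the sketch below. This is a hedged approximation only: `_DeferredShardArgs` and `_populate_batched_results` are simplified here, and `run_shard_args` / `run_cross_host_copies` are placeholders standing in for `pxla.shard_args` and the explicit `xc.batched_copy_array_to_devices_with_sharding` call, whose real signatures differ.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence

@dataclass
class DeferredEntry:
    """Simplified stand-in for a deferred shard-args entry."""
    index: int            # where the result goes in the output batch
    array: Any
    sharding: Any
    is_cross_host: bool   # the new field: route around shard_args when True

def _populate(results: list, entries: Sequence[DeferredEntry],
              outs: Sequence[Any]) -> None:
    # In the spirit of _populate_batched_results: write one group's outputs
    # back into the positions its arrays originally occupied.
    for e, out in zip(entries, outs):
        results[e.index] = out

def run_batched_put(entries: Sequence[DeferredEntry],
                    run_shard_args: Callable[[list, list], list],
                    run_cross_host_copies: Callable[[list, list], list]) -> list:
    results: list = [None] * len(entries)
    local = [e for e in entries if not e.is_cross_host]
    cross = [e for e in entries if e.is_cross_host]
    if local:
        # One shard_args-style call for everything that stays host-local.
        _populate(results, local, run_shard_args(
            [e.array for e in local], [e.sharding for e in local]))
    if cross:
        # One batched copy covering every cross-host transfer, so the runtime
        # sees the whole group at once.
        _populate(results, cross, run_cross_host_copies(
            [e.array for e in cross], [e.sharding for e in cross]))
    return results
```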
I think we can merge and simplify more. Let me patch this change in internally and try things out.