[replica-parallel] Configurable limits for replica-parallel checkpoint saving
Adds the parameters min_slice_bytes_for_replica_parallel and max_replicas_for_replica_parallel to ArrayHandler so that users can configure replica-parallel checkpoint saving. If use_replica_parallel is set to True, saving is parallelized over at most max_replicas_for_replica_parallel replicas. If the size of each replica slice, in bytes, would be less than min_slice_bytes_for_replica_parallel, replica-parallel saving is disabled.
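To make the interaction between the two limits concrete, here is a minimal sketch of the decision rule described above. The helper replicas_for_parallel_save and its shard_shape, itemsize and num_replicas arguments are hypothetical and are not the Orbax implementation. It also assumes that None means "no limit" for either parameter and that the minimum-size check is applied after the replica cap, because the cap determines the actual slice size.

```python
import math
from typing import Optional, Sequence


def replicas_for_parallel_save(
    shard_shape: Sequence[int],
    itemsize: int,
    num_replicas: int,
    use_replica_parallel: bool = True,
    min_slice_bytes_for_replica_parallel: Optional[int] = None,
    max_replicas_for_replica_parallel: Optional[int] = None,
) -> int:
    """Returns how many replicas share the write of one shard (1 = no parallelism).

    Hypothetical illustration of the rules in this PR description, not Orbax code.
    """
    if not use_replica_parallel or num_replicas <= 1:
        return 1
    # Cap the number of replicas that participate in writing this shard.
    n = num_replicas
    if max_replicas_for_replica_parallel is not None:
        n = min(n, max_replicas_for_replica_parallel)
    if n <= 1:
        return 1
    # Each participating replica writes roughly shard_bytes / n bytes.
    shard_bytes = math.prod(shard_shape) * itemsize
    slice_bytes = shard_bytes // n
    # Slices that are too small mean many tiny writes, so fall back to one writer.
    if (
        min_slice_bytes_for_replica_parallel is not None
        and slice_bytes < min_slice_bytes_for_replica_parallel
    ):
        return 1
    return n
```

For example, a 1024 x 1024 float32 shard (4 MiB) replicated 8 times is split across 4 replicas when max_replicas_for_replica_parallel=4, giving 1 MiB slices. Adding min_slice_bytes_for_replica_parallel=2 MiB then disables replica-parallel saving for that shard.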
Motivation: JAX arrays may consist of multiple shards, which may in turn be duplicated across multiple ‘replicas’. #1319 and #1320 added support for parallelizing the saving of JAX array shards over these replicas. However, combining this parallelism with a high data-parallel (DP) factor produces a large number of small writes, which can cause storage backends to throttle write requests. This PR introduces configurable limits on when to use replica-parallel saving and on how many replicas saving should be parallelized over.