jax
Add Kubernetes Jobset support to `jax.distributed.initialize`
This should allow JAX to bootstrap distributed computations automatically when running inside a Kubernetes JobSet.
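The thread does not show how the detection works. The sketch below is a rough illustration, not this PR's implementation, of what a cluster-detection step has to derive for a JobSet. It assumes a single replicated job whose pods are addressable through the JobSet's headless service as `<jobset>-<replicated-job>-<job-index>-<pod-index>.<jobset>`. `JOB_COMPLETION_INDEX` is set by Kubernetes for Indexed Jobs. Every other environment variable name, the `jobset_topology` helper, and the default port are hypothetical and chosen only for illustration.

```python
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class JobSetTopology:
    """Values a process needs to join a JAX distributed run."""
    coordinator_address: str  # "host:port" of process 0
    num_processes: int        # total pods across all jobs in the JobSet
    process_id: int           # this pod's global rank in [0, num_processes)


def jobset_topology(env: Mapping[str, str], port: int = 1234) -> JobSetTopology:
    """Derive a JAX distributed topology from environment variables.

    Hypothetical sketch: all variable names except JOB_COMPLETION_INDEX are
    assumptions (e.g. injected via the downward API); the port is arbitrary.
    """
    jobset = env["JOBSET_NAME"]                     # hypothetical name
    replicated_job = env["REPLICATED_JOB_NAME"]     # hypothetical name
    num_jobs = int(env["REPLICATED_JOB_REPLICAS"])  # hypothetical name
    pods_per_job = int(env["PODS_PER_JOB"])         # hypothetical name
    job_index = int(env["JOB_INDEX"])               # hypothetical name
    pod_index = int(env["JOB_COMPLETION_INDEX"])    # set by Kubernetes for Indexed Jobs

    if not (0 <= job_index < num_jobs and 0 <= pod_index < pods_per_job):
        raise ValueError("job or pod index out of range")

    # Assumed DNS scheme: pod 0 of job 0 is reachable through a headless
    # service whose name is assumed to equal the JobSet name.
    coordinator_host = f"{jobset}-{replicated_job}-0-0.{jobset}"
    return JobSetTopology(
        coordinator_address=f"{coordinator_host}:{port}",
        num_processes=num_jobs * pods_per_job,
        # Ranks are laid out job-major: all pods of job 0, then job 1, ...
        process_id=job_index * pods_per_job + pod_index,
    )


if __name__ == "__main__":
    example_env = {
        "JOBSET_NAME": "train",
        "REPLICATED_JOB_NAME": "workers",
        "REPLICATED_JOB_REPLICAS": "2",
        "PODS_PER_JOB": "4",
        "JOB_INDEX": "1",
        "JOB_COMPLETION_INDEX": "2",
    }
    print(jobset_topology(example_env))
```

The three fields correspond to the `coordinator_address`, `num_processes`, and `process_id` arguments of `jax.distributed.initialize`. These are the values that automatic detection would otherwise require the user to pass explicitly.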
@hawkinsp, PTAL when you can.
You need to add your new file to the list here (in sorted order) to unbreak the Bazel build: https://github.com/google/jax/blob/cd04d0f32e854aa754e37e4b676725655a94e731/jax/BUILD#L949
Added.