jaxtyping
Support typechecking of jax.sharding.NamedSharding
I love jaxtyping! Can I have more of it please?
Specifically, I'd like to make assertions about the sharding of my jax.Array objects. Given an array Float[Array, "batch seqlen channel"], I'd like to assert its sharding with syntax like Float[ShardedArray, "batch/data_parallel seqlen channel/tensor_parallel"]. This is a commonly used plain-text representation for shardings, following e.g. the notation in Figure 5 of Efficiently Scaling Transformer Inference.
The intention is that the sharding part of this syntax would parse to the partition spec jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel'). We could then check that this partition spec is equivalent to the array's actual sharding using a combination of jax.debug.inspect_array_sharding and jax.sharding.XLACompatibleSharding.is_equivalent_to.
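To make the proposed parsing step concrete, here is a minimal sketch in plain Python. It assumes each whitespace-separated dimension is either "name" (unsharded) or "name/mesh_axis" (sharded along exactly one mesh axis). The function name parse_sharded_dims is hypothetical and not part of jaxtyping. The resulting axes tuple is what you would splat into jax.sharding.PartitionSpec.

```python
from typing import Optional, Tuple


def parse_sharded_dims(spec: str) -> Tuple[Tuple[str, ...], Tuple[Optional[str], ...]]:
    """Hypothetical parser: split "dim/mesh_axis ..." into dimension names
    and per-dimension mesh axes (None where a dimension is not sharded)."""
    dims = []
    axes = []
    for token in spec.split():
        # "batch/data_parallel" -> ("batch", "/", "data_parallel");
        # "seqlen" -> ("seqlen", "", "")
        name, sep, axis = token.partition("/")
        if not name or (sep and (not axis or "/" in axis)):
            raise ValueError(f"malformed dimension {token!r} in {spec!r}")
        dims.append(name)
        axes.append(axis if sep else None)
    return tuple(dims), tuple(axes)


dims, axes = parse_sharded_dims("batch/data_parallel seqlen channel/tensor_parallel")
# dims is what the existing shape check would use; axes would feed
# jax.sharding.PartitionSpec(*axes), i.e. PartitionSpec('data_parallel', None, 'tensor_parallel').
```

Keeping the dimension names and the mesh axes separate means the existing shape check could stay unchanged, with the sharding check layered on top.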
There's a small hiccup: to convert a jax.sharding.PartitionSpec to a jax.sharding.NamedSharding, we need a jax.sharding.Mesh. A mesh is non-constant data (it contains JAX device objects), so it is undesirable to put in a type signature. I think the best user experience would be to read the mesh from a thread-local; perhaps even the one that JAX already uses for the (now-superseded) pjit: jax._src.mesh.thread_resources.env.physical_mesh (unfortunately, this is private). In that case, the sharding assertion could look like this:
```python
import functools
import jax
import jax._src.mesh as mesh_private  # private module; may change between JAX versions

def _assert_sharding_cb(ndim: int, expected: jax.sharding.XLACompatibleSharding, actual: jax.sharding.XLACompatibleSharding):
    # Invoked by inspect_array_sharding with the array's actual sharding.
    if not expected.is_equivalent_to(actual, ndim):
        raise ValueError(f'got sharding {actual}, but expected {expected}')

def assert_sharding(v: jax.Array, expected: jax.sharding.PartitionSpec):
    # Read the current mesh from the thread-local (set e.g. by `with mesh:`).
    mesh = mesh_private.thread_resources.env.physical_mesh
    expected_sharding = jax.sharding.NamedSharding(mesh, expected)
    jax.debug.inspect_array_sharding(v, callback=functools.partial(_assert_sharding_cb, v.ndim, expected_sharding))
```
Here is a complete Colab that tries this out on 8 CPUs and shows that it also works under jit: https://colab.research.google.com/drive/1oLy66BjKOWmh7dFu8aZbo_gBypDtlNeQ?usp=sharing.