
Support typechecking of jax.sharding.NamedSharding

Open reinerp opened this issue 1 year ago • 7 comments

I love jaxtyping! Can I have more of it please?

Specifically, I'd like to make assertions about the sharding of my jax.Array objects. Given an array Float[Array, "batch seqlen channel"] I'd like to assert its sharding with syntax like this: Float[ShardedArray, "batch/data_parallel seqlen channel/tensor_parallel"]. This is a commonly used plain-text notation for shardings; see e.g. Figure 5 of Efficiently Scaling Transformer Inference.

The intention is that the sharding part of this syntax would parse to the partition spec jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel'). We could then check that this partition spec is equivalent to the array's actual sharding using a combination of jax.debug.inspect_array_sharding and jax.sharding.XLACompatibleSharding.is_equivalent_to.
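A minimal sketch of the parsing step, assuming each dimension is sharded over at most one mesh axis (`name` or `name/axis`). `parse_sharded_shape` is a hypothetical helper, not part of jaxtyping, and it ignores jaxtyping's richer dimension syntax (`*batch`, `...`, etc.). It returns the plain dimension names, which rejoin into the usual `Float[Array, ...]` string, plus one mesh-axis name (or `None` for replicated) per dimension, which could then be passed as `jax.sharding.PartitionSpec(*axes)`.

```python
from typing import Optional, Tuple


def parse_sharded_shape(spec: str) -> Tuple[Tuple[str, ...], Tuple[Optional[str], ...]]:
    """Hypothetical parser: split "batch/data_parallel seqlen" into
    dimension names and per-dimension mesh axes (None means replicated)."""
    dims = []
    axes = []
    for token in spec.split():
        name, sep, axis = token.partition("/")
        # Reject "/axis", "dim/" and "dim/a/b" (only one mesh axis per dim here).
        if not name or (sep and not axis) or "/" in axis:
            raise ValueError(f"malformed dimension {token!r} in {spec!r}")
        dims.append(name)
        axes.append(axis if sep else None)
    return tuple(dims), tuple(axes)


dims, axes = parse_sharded_shape("batch/data_parallel seqlen channel/tensor_parallel")
print(" ".join(dims))  # the ordinary jaxtyping shape string
print(axes)            # arguments for jax.sharding.PartitionSpec(*axes)
```

Keeping the dimension names separate means the existing shape check could run unchanged, with the sharding check layered on top.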

There's a small hiccup: to convert a jax.sharding.PartitionSpec to a jax.sharding.NamedSharding, we need a jax.sharding.Mesh, which holds non-constant data (JAX "device" objects) and so is undesirable to put in a type signature. I think the best user experience would be to store the mesh in a thread-local; perhaps even the one that JAX already uses for the (now-superseded) pjit: jax._src.mesh.thread_resources.env.physical_mesh (unfortunately, this is private). In that case, the sharding assertion could look like this:

```python
import functools

import jax
import jax._src.mesh as mesh_private  # private module; may change between JAX versions


def _assert_sharding_cb(ndim: int,
                        expected: jax.sharding.XLACompatibleSharding,
                        actual: jax.sharding.XLACompatibleSharding):
    # Called by inspect_array_sharding with the array's actual sharding.
    if not expected.is_equivalent_to(actual, ndim):
        raise ValueError(f'got sharding {actual}, but expected {expected}')


def assert_sharding(v: jax.Array, expected: jax.sharding.PartitionSpec):
    # Resolve the PartitionSpec against the mesh held in JAX's thread-local state.
    mesh = mesh_private.thread_resources.env.physical_mesh
    expected_sharding = jax.sharding.NamedSharding(mesh, expected)
    jax.debug.inspect_array_sharding(
        v, callback=functools.partial(_assert_sharding_cb, v.ndim, expected_sharding))
```

A complete Colab that tries this out on 8 CPUs, and shows that it also works under jit: https://colab.research.google.com/drive/1oLy66BjKOWmh7dFu8aZbo_gBypDtlNeQ?usp=sharing.

reinerp, Feb 04 '24 18:02