
Question about autoparallelism

Open vadmbertr opened this issue 10 months ago • 3 comments

Hi!

There is something that I do not understand in the Autoparallelism example: eqx.filter_shard is used both outside and inside the jit-ted function make_step.

Is the purpose of the inside call to help the compiler know how Pytrees should be sharded (as described in JAX docs)? If so, can I assume that it's optional?

I am asking because XLA raises an error when sharding an Array of PRNG keys. See this MWE for example:

import equinox as eqx
import jax
import jax.random as jrd
import jax.sharding as jshard

num_devices = len(jax.devices())  # 2 in my case
mesh = jax.make_mesh((num_devices,), ("batch",))
sharded = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))

batch_size = 64
key = jrd.key(0)
batch_key, key = jrd.split(key, 2)
key_batch = jrd.split(batch_key, batch_size)

key_batch = eqx.filter_shard(key_batch, sharded)

@eqx.filter_jit
def fun1(_key_batch):
    uniform = eqx.filter_vmap(jax.random.uniform)(_key_batch)
    return uniform

@eqx.filter_jit
def fun2(_key_batch, _sharded):
    _key_batch = eqx.filter_shard(_key_batch, _sharded)
    uniform = eqx.filter_vmap(jax.random.uniform)(_key_batch)
    uniform = eqx.filter_shard(uniform, _sharded)
    return uniform

uni1 = fun1(key_batch)  # works
jax.debug.visualize_array_sharding(uni1)  # sharded on both GPUs

uni2 = fun2(key_batch, sharded)  # raises jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: Number of tile assignment dimensions (excluding subgroups) is different than the input rank. sharding={devices=[2]<=[2]}, input_shape=u32[64,2]

Thank you for the clarification. Vadim

vadmbertr avatar Mar 17 '25 10:03 vadmbertr

So filter_shard is just a thin wrapper around wsc (jax.lax.with_sharding_constraint):

https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_sharding.py#L40

It's designed to (a) make it easier to handle static arguments and (b) unify device/sharding information.

So indeed it's designed to help the compiler know how to do sharding. Using it outside the JIT'd region should actually move the data to the other devices, which is needed for the jax.jit'd function to then execute on those devices.

patrick-kidger avatar Mar 17 '25 15:03 patrick-kidger
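
The outside/inside distinction described above can be sketched with plain JAX calls. This is only an illustration under stated assumptions, not the code from the Autoparallelism example: it uses jax.device_put for the placement outside the JIT'd region and jax.lax.with_sharding_constraint for the hint inside it, and the function name step, the array shapes, and the batch size are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices (works with a single device too).
mesh = Mesh(np.array(jax.devices()), ("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Hypothetical data: the batch size is chosen to divide evenly across devices.
batch_size = 8 * len(jax.devices())
x = jnp.arange(batch_size * 4, dtype=jnp.float32).reshape(batch_size, 4)

# Outside jit: this actually places the data on the devices.
x = jax.device_put(x, sharding)


@jax.jit
def step(x):
    # Inside jit: these are hints to the compiler about the desired layout,
    # not explicit data transfers.
    x = jax.lax.with_sharding_constraint(x, sharding)
    y = x * 2.0
    y = jax.lax.with_sharding_constraint(y, sharding)
    return y


y = step(x)
jax.debug.visualize_array_sharding(y)
```

In this sketch the constraints inside step only guide the compiler, while the device_put call outside is what distributes the input in the first place.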

Thank you for the explanation! Any idea why XLA complains when using filter_shard or lax.with_sharding_constraint inside a JIT'd function with PRNG keys? Maybe it's more of a question/issue for JAX or XLA folks...

vadmbertr avatar Mar 18 '25 07:03 vadmbertr

Any idea why XLA complains when using filter_shard or lax.with_sharding_constraint inside a JIT'd function with PRNG keys? Maybe it's more of a question/issue for JAX or XLA folks...

That I'm not super sure about, I'm afraid -- indeed probably a question for the JAX folks!

patrick-kidger avatar Mar 18 '25 16:03 patrick-kidger
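
One detail that may be useful when raising this with the JAX folks: the error message reports input_shape=u32[64,2], whereas the key array in the MWE has shape (64,). The sketch below, which assumes the default PRNG implementation, shows that a typed key array carries an extra trailing dimension of raw uint32 data. Whether this rank difference is the actual cause of the error is not confirmed anywhere in this thread.

```python
import jax
import jax.random as jrd

# Same construction as in the MWE above, without any sharding.
key = jrd.key(0)
key_batch = jrd.split(key, 64)

# The typed key array is rank 1 from the user's point of view...
print(key_batch.shape)

# ...but its underlying raw data has an additional trailing dimension.
raw = jrd.key_data(key_batch)
print(raw.shape, raw.dtype)
```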