Question about autoparallelism
Hi!
There is something that I do not understand in the Autoparallelism example: eqx.filter_shard is used both outside and inside the jit-ted function make_step.
Is the purpose of the inside call to help the compiler know how pytrees should be sharded (as described in the JAX docs)? If so, can I assume that it's optional?
I am asking because XLA raises an error when sharding an Array of PRNG keys. See this MWE for example:
```python
import equinox as eqx
import jax
import jax.random as jrd
import jax.sharding as jshard

num_devices = len(jax.devices())  # 2 in my case
mesh = jax.make_mesh((num_devices,), ("batch",))
sharded = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))

batch_size = 64
key = jrd.key(0)
batch_key, key = jrd.split(key, 2)
key_batch = jrd.split(batch_key, batch_size)
key_batch = eqx.filter_shard(key_batch, sharded)

@eqx.filter_jit
def fun1(_key_batch):
    uniform = eqx.filter_vmap(jax.random.uniform)(_key_batch)
    return uniform

@eqx.filter_jit
def fun2(_key_batch, _sharded):
    _key_batch = eqx.filter_shard(_key_batch, _sharded)
    uniform = eqx.filter_vmap(jax.random.uniform)(_key_batch)
    uniform = eqx.filter_shard(uniform, _sharded)
    return uniform

uni1 = fun1(key_batch)  # works
jax.debug.visualize_array_sharding(uni1)  # sharded on both GPUs

# raises jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error:
# Number of tile assignment dimensions (excluding subgroups) is different
# than the input rank. sharding={devices=[2]<=[2]}, input_shape=u32[64,2]
uni2 = fun2(key_batch, sharded)
```
Thank you for the clarification. Vadim
So `filter_shard` is just a thin wrapper around `with_sharding_constraint`:
https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_sharding.py#L40
it's designed (a) to make it easier to handle static arguments and (b) to unify device/sharding information.
So indeed it's designed to help the compiler know how to do the sharding. Using it outside the JIT'd region should actually move the data to the other devices, which is needed for the `jax.jit`'d function to then execute on those devices.
Thank you for the explanation!
Any idea why XLA complains when using filter_shard or lax.with_sharding_constraint inside a JIT'd function with PRNG keys? Maybe it's more of a question/issue for JAX or XLA folks...
> Any idea why XLA complains when using filter_shard or lax.with_sharding_constraint inside a JIT'd function with PRNG keys? Maybe it's more of a question/issue for JAX or XLA folks...
That I'm not super sure about, I'm afraid -- indeed, probably a question for the JAX folks!