jax
Corner-case sharding regression when replacing concrete mesh with abstract mesh
Description
When the constrained value has no data dependency on the function's inputs, jax.lax.with_sharding_constraint behaves differently depending on whether the NamedSharding is built from an abstract mesh or a concrete mesh. Only the concrete mesh behaves as expected.
I'm sending a PR with this test and will reference this issue from it. The reproduction is essentially:
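import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import NamedSharding, PartitionSpec as P

# `mesh` is either a concrete Mesh or an abstract mesh.
def f(x):
  # This breaks the data dependency on x.
  x = with_sharding_constraint(
      jnp.ones(shape=x.shape, dtype=x.dtype),
      NamedSharding(mesh, P('x')),
  )
  return x * 2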
When mesh is abstract, the resulting array is not sharded as expected. When mesh is concrete, the sharding works as intended.
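A self-contained sketch of the concrete-mesh half of the reproduction may help readers check the expected behavior locally. It assumes a single mesh axis named 'x' spanning every available device and forces 8 virtual CPU devices through XLA_FLAGS so the axis actually splits. The abstract-mesh half is deliberately omitted, since the way an abstract mesh is created and put into context has changed across JAX releases.

```python
import os

# Assumption: request 8 virtual CPU devices so the 'x' axis spans more than one
# device. This only affects the CPU backend, and only if XLA_FLAGS is not
# already set and JAX has not been imported yet.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Concrete mesh over every available device, with one axis named 'x'.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))


@jax.jit
def f(x):
  # The constrained value is built only from x's shape and dtype, so it has
  # no data dependency on x itself.
  y = jax.lax.with_sharding_constraint(
      jnp.ones(shape=x.shape, dtype=x.dtype),
      NamedSharding(mesh, P("x")),
  )
  return y * 2


# Leading dimension is a multiple of the device count so it divides evenly.
x = jnp.zeros((2 * len(devices),), dtype=jnp.float32)
out = f(x)

# Inspect the output sharding and compare it with the requested one.
print(out.sharding)
print(out.sharding.is_equivalent_to(NamedSharding(mesh, P("x")), out.ndim))
```

Swapping the concrete mesh for an abstract one, with everything else unchanged, is the case the report describes as not producing the requested sharding.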
System info (python version, jaxlib version, accelerator, etc.)
Internal Google test.