
Corner-case sharding regression when replacing concrete mesh with abstract mesh

Open jkr26 opened this issue 5 months ago • 2 comments

Description

When the constrained value has no data dependency on the function's inputs, passing an abstract mesh instead of a concrete mesh to `jax.lax.with_sharding_constraint` produces different behavior. Only the concrete mesh behaves as expected.

I'm sending a PR with this test and will tag this issue. The reproduction is essentially:

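```python
import jax.numpy as jnp
from jax.lax import with_sharding_constraint
from jax.sharding import NamedSharding, PartitionSpec as P

# `mesh` is defined by the test and is either a concrete or an abstract mesh.
def f(x):
  # Building a fresh array from x's shape and dtype breaks the data dependency on x.
  x = with_sharding_constraint(
      jnp.ones(shape=x.shape, dtype=x.dtype),
      NamedSharding(mesh, P('x')),
  )
  return x * 2
```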

When `mesh` is abstract, the resulting array is not sharded as expected. When `mesh` is concrete, the sharding works as intended (WAI).
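To check the output sharding on the working (concrete-mesh) path, a minimal sketch like the following runs the reproduction under `jax.jit` and prints the sharding of the result. The mesh built from all local devices, the input length, and the use of `is_equivalent_to` for the comparison are illustrative assumptions, not part of the original test. To exercise the abstract-mesh case, the `NamedSharding` would instead be built from an abstract mesh with the same axis names and sizes. The `AbstractMesh` constructor has changed across JAX versions, so that variant is not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a concrete 1-D mesh over all local devices, axis named 'x'.
mesh = Mesh(np.array(jax.devices()), ('x',))

@jax.jit
def f(x):
  # The constrained value depends only on x's shape and dtype, not its data.
  y = jax.lax.with_sharding_constraint(
      jnp.ones(shape=x.shape, dtype=x.dtype),
      NamedSharding(mesh, P('x')),
  )
  return y * 2

n = len(jax.devices())
x = jnp.zeros((8 * n,), dtype=jnp.float32)  # length divisible by the mesh size
out = f(x)

expected = NamedSharding(mesh, P('x'))
print(out.sharding)
# With a concrete mesh, the constraint is expected to carry through to the output.
print(out.sharding.is_equivalent_to(expected, out.ndim))
```

On a single-device host the `'x'` axis has size 1, so the constraint is trivially satisfied. The comparison is only informative with several devices, for example several CPU devices via `XLA_FLAGS=--xla_force_host_platform_device_count=8`.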

System info (python version, jaxlib version, accelerator, etc.)

Internal Google test.

jkr26 · Sep 13 '24 20:09