`jax.lax.with_sharding_constraint` is not propagated outside jit
I would expect the two functions init0 and init1 to be identical:
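import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Mesh with a single axis named 'devices' spanning all devices (8 here).
mesh = Mesh(jax.devices(), axis_names=('devices',))
REPLICATED = NamedSharding(mesh, PartitionSpec())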
@functools.partial(
    jax.jit,
    out_shardings=REPLICATED,
)
def init0():
    return jnp.zeros((3,))

@jax.jit
def init1():
    return jax.lax.with_sharding_constraint(jnp.zeros((3,)), REPLICATED)
However, using jax.lax.with_sharding_constraint loses the original sharding information (e.g. the sharding name) outside of jit:
init0().sharding # NamedSharding(mesh=Mesh('devices': 8), spec=PartitionSpec())
init1().sharding # GSPMDSharding({replicated})
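If downstream code only needs to know that the result is replicated, and not which sharding class carries that information, one option is to query the `is_fully_replicated` property instead of comparing reprs. The following is a minimal sketch under that assumption; the mesh construction and the lowercase name `replicated` are illustrative and not taken from the report.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: a 1-D mesh over every visible device, with the axis named 'devices'.
mesh = Mesh(jax.devices(), axis_names=("devices",))
replicated = NamedSharding(mesh, PartitionSpec())


@jax.jit
def init1():
    return jax.lax.with_sharding_constraint(jnp.zeros((3,)), replicated)


out = init1()
# The concrete class may be NamedSharding or GSPMDSharding depending on the
# JAX version, but the replication property is available on either.
print(type(out.sharding).__name__, out.sharding.is_fully_replicated)
```

This does not recover the mesh axis names, so it is only a stopgap when the names themselves are not needed.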
The reason I prefer jax.lax.with_sharding_constraint is that evaluating out_shardings= fails: absl.app.run has not been called yet when the function is defined, so jax.devices() and creating the sharding both fail.
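One possible way around the definition-time problem, sketched here as an assumption and not as a recommended JAX pattern, is to build both the sharding and the jit wrapper lazily, so that jax.devices() is first touched when init0 is called (for example from inside main, after absl.app.run has started). The helpers `replicated_sharding`, `_init` and `jitted_init` are hypothetical names introduced for illustration.

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


@functools.cache
def replicated_sharding():
    # Hypothetical helper: jax.devices() is only called on first use,
    # not when this module is imported.
    mesh = Mesh(jax.devices(), axis_names=("devices",))
    return NamedSharding(mesh, PartitionSpec())


def _init():
    return jnp.zeros((3,))


@functools.cache
def jitted_init():
    # Build the jit wrapper lazily so out_shardings is evaluated at first call.
    return jax.jit(_init, out_shardings=replicated_sharding())


def init0():
    return jitted_init()()


out = init0()
print(out.sharding)
```

Caching the wrapper keeps a single jitted function around, so repeated calls to init0 reuse the same compiled computation.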
I don't think this is related to with_sharding_constraint. In my experience, every operation applied to an object with NamedSharding transforms it into an object with GSPMDSharding, losing the PartitionSpec names (but still being sharded correctly).
import os
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
num_gpus = 8
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))
with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)
NamedSharding(mesh={'gpus': 8}, spec=PartitionSpec('gpus',))
GSPMDSharding({devices=[8]0,1,2,3,4,5,6,7})
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
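The observation that the result is "still being sharded correctly" can also be checked programmatically by comparing the per-device index maps of the two shardings instead of their classes. This is a small sketch assuming a CPU-only run; it uses jax.devices() directly in place of mesh_utils and sizes the array from the device count, both of which are choices made for this example.

```python
import os

# Assumption: CPU-only run; the flag must be set before JAX initializes its backend.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, axis_names=("gpus",))
named = NamedSharding(mesh, P("gpus"))

# Two elements per device, so the array divides evenly across the mesh axis.
arr = jax.device_put(jnp.zeros((2 * len(devices),)), named)
out = arr * 2

# Compare which slice of the array each device holds, ignoring the sharding class.
same_layout = (
    out.sharding.devices_indices_map(out.shape)
    == named.devices_indices_map(out.shape)
)
print(type(out.sharding).__name__, same_layout)
```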
Yeah, this should ideally work and return NamedSharding.
This is a rare case, with no arguments and no annotations on jit, where it doesn't work.
But thanks for reporting. I'll try to fix it.
Hi @Findus23
It looks like the issue you mentioned has been resolved. I tried to reproduce it on Colab with JAX version 0.4.23, and operations applied to an object with NamedSharding no longer transform it into an object with GSPMDSharding.
However, the issue mentioned by @Conchylicultor still exists.
import os
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
num_gpus = 8
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))
with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    arr = arr + 4
    print(arr.sharding)
    arr = arr - 1
    print(arr.sharding)
    arr = arr / 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)
Output:
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
Please find the gist for reference.
Thank you
Indeed, it seems like since 0.4.19 most operations are returning proper NamedSharding: https://github.com/Findus23/jax-array-info/commit/fd641005656c3c23f9b854fe0e7992e9a5937864