
`jax.lax.with_sharding_constraint` is not propagated outside jit

Conchylicultor opened this issue 1 year ago · 4 comments

I would expect the two functions init0 and init1 to behave identically:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))
REPLICATED = NamedSharding(mesh, PartitionSpec())


@functools.partial(
    jax.jit,
    out_shardings=REPLICATED,
)
def init0():
  return jnp.zeros((3,))


@jax.jit
def init1():
  return jax.lax.with_sharding_constraint(jnp.zeros((3,)), REPLICATED)

However, using jax.lax.with_sharding_constraint loses the original sharding information (e.g. the axis names) outside of jit:

init0().sharding  # NamedSharding(mesh=Mesh('devices': 8), spec=PartitionSpec())
init1().sharding  # GSPMDSharding({replicated})

The reason jax.lax.with_sharding_constraint is preferred here is that evaluating out_shardings= fails at function definition time: absl.app.run has not been called yet, so jax.devices() and constructing the sharding both fail.
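One way around the definition-time failure is to build the sharding lazily, on first call rather than at import time. This is only a sketch, not from the original report; the names replicated_sharding and the inner lambda are illustrative:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


@functools.cache
def replicated_sharding():
    # Built on first call, i.e. after the runtime (and e.g. absl.app.run)
    # has set up the devices; cached so the mesh is only constructed once.
    mesh = Mesh(np.array(jax.devices()), axis_names=('devices',))
    return NamedSharding(mesh, PartitionSpec())


def init0():
    # jit is applied at call time, so out_shardings can use the lazily
    # built sharding instead of requiring it at module import time.
    fn = jax.jit(lambda: jnp.zeros((3,)), out_shardings=replicated_sharding())
    return fn()
```

With this shape of code, out_shardings= can be used without with_sharding_constraint, at the cost of constructing the jitted function at call time.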

Conchylicultor avatar Sep 04 '23 08:09 Conchylicultor

I don't think this is related to with_sharding_constraint. In my experience, every operation applied to an object with NamedSharding turns it into an object with GSPMDSharding, losing the PartitionSpec names (though the array is still sharded correctly).

import os

# Must be set before jax initializes its backends.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)
NamedSharding(mesh={'gpus': 8}, spec=PartitionSpec('gpus',))
GSPMDSharding({devices=[8]0,1,2,3,4,5,6,7})
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
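Even when the repr shows GSPMDSharding, the underlying placement can be checked against the intended NamedSharding with Sharding.is_equivalent_to. A sketch, assuming the same forced 8-CPU-device setup as above:

```python
import os

# Must be set before jax initializes its backends.
os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('gpus',))
intended = NamedSharding(mesh, P('gpus'))

arr = jax.device_put(jax.numpy.zeros((16,)), intended)
arr = arr * 2
# The concrete sharding class of arr.sharding may differ between jax
# versions, but the device placement still matches the intended spec:
assert arr.sharding.is_equivalent_to(intended, arr.ndim)
print(type(arr.sharding).__name__)
```

This confirms the "still being sharded correctly" observation: only the PartitionSpec names are lost, not the placement.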

Findus23 avatar Sep 04 '23 10:09 Findus23

Yeah, this should ideally work and return NamedSharding.

This is a rare case where it doesn't work: a jit with no arguments and no sharding annotations.

But thanks for reporting. I'll try to fix it.

yashk2810 avatar Sep 04 '23 14:09 yashk2810

Hi @Findus23

It looks like the issue you mentioned has been resolved. I tried to reproduce it on Colab with JAX version 0.4.23. Operations applied to an object with NamedSharding no longer transform it into an object with GSPMDSharding.

But the issue mentioned by @Conchylicultor still exists.

import os

# Must be set before jax initializes its backends.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    arr = arr + 4
    print(arr.sharding)
    arr = arr - 1
    print(arr.sharding)
    arr = arr / 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)

Output:

NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
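The same chain of elementwise ops can be scripted as a quick regression check. A sketch, assuming the 8-CPU-device setup above; it asserts placement equivalence (which holds on any version) and prints the sharding class, which on recent jax should be NamedSharding at every step:

```python
import os

# Must be set before jax initializes its backends.
os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('gpus',))
intended = NamedSharding(mesh, P('gpus'))
arr = jax.device_put(jax.numpy.zeros((16,)), intended)

# Apply the same ops as above, checking the sharding after each one.
for op in (lambda x: x * 2, lambda x: x + 4, lambda x: x - 1, lambda x: x / 2):
    arr = op(arr)
    assert arr.sharding.is_equivalent_to(intended, arr.ndim)
    print(type(arr.sharding).__name__)
```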

Please find the gist for reference.

Thank you

rajasekharporeddy avatar Feb 27 '24 06:02 rajasekharporeddy

Indeed, it seems that since 0.4.19 most operations return a proper NamedSharding: https://github.com/Findus23/jax-array-info/commit/fd641005656c3c23f9b854fe0e7992e9a5937864

Findus23 avatar Feb 27 '24 10:02 Findus23