[export] grad of exported function fails with device assignment error
Description
jax.experimental.export exports the VJP using a synthetic mesh built from the first N devices of the export platform. This seemed reasonable because the exported artifact captures only the number of devices, not their ids. However, during lowering there may be a conflict between the device order used by the exporting code and the shardings present in the primal function.
For example, the following code fails:
import jax
import jax.numpy as jnp
from jax.experimental import export, pjit
from jax.sharding import Mesh, NamedSharding

def f(x):
  return jnp.sum(x * 2.)

# Shard the input over a mesh with the local devices in reversed order.
mesh_rev = Mesh(list(reversed(jax.local_devices())), "i")
shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",)))
input_no_shards = jnp.ones(shape=(jax.local_device_count(),))
input_rev = jax.device_put(input_no_shards, device=shardings_rev)

exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)
g = jax.grad(export.call(exp_rev))(input_rev)  # Failure here
The failure is:
E ValueError: Received incompatible devices for pjitted computation. Got ARG_SHARDING with device ids [0, 1] on platform CPU and pjit inside pjit with device ids [1, 0] on platform CPU at /Users/necula/Source/jax/tests/export_test.py:1091:14 (JaxExportTest.test_grad_sharding_different_mesh)
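To make the device-order conflict concrete, here is a minimal sketch (not from the repro above) that compares the device ids of the reversed mesh used by the primal sharding with a mesh over the first N local devices in default order, which is how the synthetic export mesh is described above:

import jax
from jax.sharding import Mesh

n = jax.local_device_count()
# Mesh as constructed by the user: devices in reversed order.
mesh_rev = Mesh(list(reversed(jax.local_devices())), "i")
# Mesh in the style of the synthetic export mesh: first n devices in default order.
mesh_default = Mesh(jax.local_devices()[:n], "i")

print([d.id for d in mesh_rev.devices.flat])      # e.g. [1, 0]
print([d.id for d in mesh_default.devices.flat])  # e.g. [0, 1]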
Instead, the export of the VJP should use the same device assignment as the export of the primal function.
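Until then, a possible user-side workaround (a sketch, not verified here) is to place the primal sharding over the local devices in their default order, so that it matches the device order of the synthetic mesh used when exporting the VJP. Continuing the repro above:

# Sketch of a possible workaround: use the default device order for the mesh,
# assuming the synthetic export mesh uses jax.local_devices() in default order.
mesh_default = Mesh(jax.local_devices(), "i")
shardings_default = NamedSharding(mesh_default, jax.sharding.PartitionSpec(("i",)))
input_default = jax.device_put(input_no_shards, device=shardings_default)

exp = export.export(pjit.pjit(f, in_shardings=shardings_default))(input_no_shards)
g = jax.grad(export.call(exp))(input_default)  # No device-order conflict expected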
System info (python version, jaxlib version, accelerator, etc.)
Not needed