Jax 0.4.27 parallelism errors
The latest version of JAX seems to break things when you parallelize: both pmap and sharding fail with the same error. Here is a minimal reproducible example:
import os
import multiprocessing
# Configure the runtime through environment variables ahead of the
# `import jax` that follows in this script:
#   * XLA_FLAGS asks XLA to expose one host (CPU) device per CPU core, so the
#     pmap/sharding code below has several devices to spread work across.
#   * EQX_ON_ERROR selects Equinox's runtime-error mode ("nan").
os.environ.update(
    {
        "XLA_FLAGS": "--xla_force_host_platform_device_count={}".format(
            multiprocessing.cpu_count()
        ),
        "EQX_ON_ERROR": "nan",
    }
)
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
from diffrax import *
def f(t, y, args):
    """Vector field of the test ODE: ``dy/dt = sin(t) + theta * y``.

    Args:
        t: Current time.
        y: Current state.
        args: Mapping that provides the coefficient under the key ``"theta"``.

    Returns:
        The time derivative ``sin(t) + args["theta"] * y``.
    """
    return jnp.sin(t) + args["theta"] * y
# Integration window [t0, t1] and the initial step size handed to the solver.
t0, t1, dt0 = 0.0, 0.04, 0.02
# Shape/dtype descriptor; not referenced by the rest of the visible snippet.
diffusion_shape = jax.ShapeDtypeStruct(shape=(1,), dtype="float32")
# Heun's method driven by a PID step-size controller.
solver = Heun()
cont = PIDController(1e-3, 1e-6)
# 100 evenly spaced times across the integration window.
ts = jnp.linspace(t0, t1, num=100)
def solve(init, key, args):
    """Solve the ODE defined by ``f`` for a single initial condition.

    Uses the module-level ``solver``, ``cont``, ``t0``, ``t1`` and ``dt0``.

    Args:
        init: Initial state, passed to ``diffeqsolve`` as ``y0``.
        key: PRNG key. Not used by this solve; it is kept so the signature
            matches the batched call sites, which map over ``(init, key)``.
        args: Extra arguments forwarded to ``f`` (a mapping with a
            ``"theta"`` entry).

    Returns:
        ``sol.ys``: the solution saved at 100 evenly spaced times between
        ``t0`` and ``t1``.
    """
    # Local save grid, built here rather than read from the module-level `ts`.
    save_times = jnp.linspace(t0, t1, 100)
    sol = diffeqsolve(
        ODETerm(f),
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=SaveAt(ts=save_times),
        stepsize_controller=cont,
    )
    return sol.ys
# Reproducer driver: run the same batched solve once on sharded inputs under
# jit+vmap and once under pmap.
# Batch of 30 identical initial conditions (each of shape (1,)), with one PRNG
# key per batch element.
batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}
# Arrange every visible device into a (num_devices, 1) mesh and derive a
# sharding from it, plus the variant returned by `.replicate()`.
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()
# pmap inputs carry an explicit leading device axis.
# NOTE(review): this reshape only works when batch_size is a multiple of
# num_devices (30 elements cannot fill num_devices * (30 // num_devices)
# slots otherwise) -- confirm for the machine this runs on.
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])
# `args` is placed with the replicated sharding; the batched inputs are placed
# with the device-mesh sharding.
args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
# Map `solve` over the leading axis of (init, key); `args` is not mapped.
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
print('shard')
# This is the call at which the traceback below reports the ValueError.
_ = fn(x, y, args_shard).block_until_ready()
print('pmap')
# Same function wrapped in pmap over the explicit device axis.
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in <cell line: 59>()
57 fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
58 print('shard')
---> 59 _ = fn(x, y, args_shard).block_until_ready()
60 print('pmap')
61 _ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()
[... skipping hidden 20 frame]
11 frames
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in solve(init, key, args)
27 ts = jnp.linspace(t0, t1, 100)
28 saving = SaveAt(ts=ts)
---> 29 sol = diffeqsolve(
30 terms,
31 solver,
[... skipping hidden 15 frame]
[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump)
914 #
915
--> 916 final_state, aux_stats = adjoint.loop(
917 args=args,
918 terms=terms,
[... skipping hidden 1 frame]
[/usr/local/lib/python3.10/dist-packages/diffrax/_adjoint.py](https://localhost:8080/#) in loop(***failed resolving arguments***)
286 )
287 msg = None
--> 288 final_state = self._loop(
289 terms=terms,
290 saveat=saveat,
[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop)
437 static_made_jump = init_state.made_jump
438 static_result = init_state.result
--> 439 _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
440 if traced_jump:
441 static_made_jump = None
[... skipping hidden 14 frame]
[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in body_fun_aux(state)
238 #
239
--> 240 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
241 terms,
242 state.tprev,
[... skipping hidden 1 frame]
[/usr/local/lib/python3.10/dist-packages/diffrax/_solver/runge_kutta.py](https://localhost:8080/#) in step(***failed resolving arguments***)
1147 # "triangular computations" (every stage depends on all previous stages)
1148 # without spurious copies.
-> 1149 final_val = eqxi.while_loop(
1150 cond_stage,
1151 rk_stage,
[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/loop.py](https://localhost:8080/#) in while_loop(***failed resolving arguments***)
105 elif kind == "checkpointed":
106 del kind, base
--> 107 return checkpointed_while_loop(
108 cond_fun,
109 body_fun,
[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in checkpointed_while_loop(***failed resolving arguments***)
247 body_fun_ = filter_closure_convert(body_fun_, init_val_)
248 vjp_arg = (init_val_, body_fun_)
--> 249 final_val_ = _checkpointed_while_loop(
250 vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
251 )
[... skipping hidden 8 frame]
[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in _checkpointed_while_loop(***failed resolving arguments***)
268 _body_fun = lambda x: body_fun(x) # hashable wrapper; JAX issue #13554
269 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
--> 270 return while_loop(cond_fun, _body_fun, init_val)
271
272
[/usr/lib/python3.10/contextlib.py](https://localhost:8080/#) in inner(*args, **kwds)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
80 return inner
81
[... skipping hidden 9 frame]
[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in <lambda>(x)
266 del checkpoints, buffers, max_steps
267 init_val, body_fun = vjp_arg
--> 268 _body_fun = lambda x: body_fun(x) # hashable wrapper; JAX issue #13554
269 while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
270 return while_loop(cond_fun, _body_fun, init_val)
[... skipping hidden 1 frame]
[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in eval_jaxpr(jaxpr, consts, propagate_source_info, *args)
447 env: dict[Var, Any] = {}
448 map(write, jaxpr.constvars, consts)
--> 449 map(write, jaxpr.invars, args)
450 lu = last_used(jaxpr)
451 for eqn in jaxpr.eqns:
ValueError: safe_map() argument 2 is shorter than argument 1
I just took the code from https://github.com/patrick-kidger/diffrax/issues/407
It's possible this is an error in Equinox, but I wasn't able to replicate it exactly without Diffrax; e.g. the following works fine:
import os
import multiprocessing
# Configure the runtime through environment variables ahead of the
# `import jax` that follows in this script:
#   * XLA_FLAGS asks XLA to expose one host (CPU) device per CPU core.
#   * EQX_ON_ERROR selects Equinox's runtime-error mode ("nan").
os.environ.update(
    {
        "XLA_FLAGS": "--xla_force_host_platform_device_count={}".format(
            multiprocessing.cpu_count()
        ),
        "EQX_ON_ERROR": "nan",
    }
)
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import equinox.internal as eqxi
import functools as ft
def f(t, y, theta):
    """Vector field of the test problem: ``dy/dt = |sin(t)| + theta * y``.

    Args:
        t: Current time.
        y: Current state.
        theta: Linear coefficient applied to ``y``.

    Returns:
        The time derivative ``abs(sin(t)) + theta * y``.
    """
    return jnp.abs(jnp.sin(t)) + theta * y
# Wrap Equinox's internal while loop twice with `jax.named_call`, under the
# names "inner-loop" and "outer-loop".
_inner_loop, _outer_loop = (
    jax.named_call(eqxi.while_loop, name=label)
    for label in ("inner-loop", "outer-loop")
)
def solve(init, key):
    """Run a nested pair of while loops that mimics an ODE solve without Diffrax.

    The outer loop runs while ``count < 5``. On each pass it draws a fresh
    ``y`` from ``jax.random.uniform`` (seeded by ``count``) and hands it to the
    inner loop, which takes explicit steps of size 0.1 along ``f`` until
    ``y.squeeze()`` is no longer below 10.

    Args:
        init: Initial ``y`` entry of the outer-loop state.
        key: PRNG key used to draw the starting value of ``count`` via
            ``jax.random.randint(key, minval=-2, maxval=2, shape=())``.

    Returns:
        The ``y`` component (index 1) of the final outer-loop state.
    """

    def inner_loop_cond(state):
        # Keep stepping while the (single-element) state is below 10.
        _, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        # One explicit step of size 0.1 in both time and state.
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        # Replace y with a fresh uniform sample before running the inner loop.
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(
            inner_loop_cond, inner_loop_body, (t, y, theta)
        )
        return (new_t, new_y, theta, count + 1)

    # Both loops use the "lax" kind of the Equinox-internal while loop.
    inner_while_loop = ft.partial(_inner_loop, kind="lax")
    outer_while_loop = ft.partial(_outer_loop, kind="lax")

    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(
        outer_loop_cond,
        outer_loop_body,
        (t_initial, y_initial, theta, count_initial),
    )
    return final_state[1]
# Control-experiment driver: same sharded-jit and pmap calls as the Diffrax
# reproducer, but on the pure Equinox/JAX `solve` defined above.
# Batch of 30 identical initial values (each of shape (1,)), with one PRNG key
# per batch element.
batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
# Arrange every visible device into a (num_devices, 1) mesh and derive a
# sharding from it.
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
# NOTE(review): `replicated` is not used again in this snippet.
replicated = sharding.replicate()
# pmap inputs carry an explicit leading device axis.
# NOTE(review): this reshape only works when batch_size is a multiple of
# num_devices -- confirm for the machine this runs on.
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])
x, y = eqx.filter_shard((inits, keys), sharding)
# Map `solve` over the leading axis of both arguments.
fn = eqx.filter_jit(eqx.filter_vmap(solve))
pmap_fn = eqx.filter_pmap(fn)
print('shard')
_ = fn(x, y).block_until_ready()
print('pmap')
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()
The issue referenced by the code comment in the stack trace: https://github.com/google/jax/issues/13554
All of the above code works 100% fine in JAX 0.4.26, by the way (well, the sharding is still slower, but that's a different issue).
Thanks for the report! Looks like an upstream JAX bug. I've opened https://github.com/google/jax/issues/21116.
Great, closing! This is fixed in JAX 0.4.28.