diffrax icon indicating copy to clipboard operation
diffrax copied to clipboard

JAX 0.4.27 parallelism errors

Open lockwo opened this issue 1 year ago • 4 comments

The latest version of JAX (0.4.27) seems to break things when you parallelize: both `pmap` and sharding fail with the same error. Here is a minimal reproducible example:

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(  # one host device per CPU core
    multiprocessing.cpu_count()
)  # set before `import jax` below, as XLA_FLAGS generally must be
os.environ['EQX_ON_ERROR'] = 'nan'  # equinox runtime checks yield NaNs instead of raising
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
from diffrax import *

def f(t, y, args):  # ODE vector field: dy/dt = sin(t) + args["theta"] * y
    return args["theta"] * y + jnp.sin(t)

t0 = 0.  # integration start time
t1 = 0.04  # integration end time
dt0 = 0.02  # passed as dt0 to diffeqsolve
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")  # NOTE(review): unused in this script
solver, cont = Heun(), PIDController(1e-3, 1e-6)  # presumably rtol=1e-3, atol=1e-6 — confirm arg order
ts = jnp.linspace(t0, t1, 100)  # NOTE(review): unused; solve() builds its own identical local ts

def solve(init, key, args):  # solve one ODE from `init`; `key` is accepted but never used
    vf = ODETerm(f)  # deterministic ODE term only (no noise), hence `key` goes unused
    terms = vf  # NOTE(review): redundant alias of `vf`
    ts = jnp.linspace(t0, t1, 100)
    saving = SaveAt(ts=ts)
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,  # one row of `inits` when called through filter_vmap below
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,  # save at the 100 evenly spaced times in [t0, t1]
        stepsize_controller=cont,  # the module-level PIDController
    )
    return sol.ys  # solution values at the save times

batch_size = 30  # number of independent ODE solves
inits = 0.1 * jnp.ones((batch_size, 1))  # one length-1 initial state per solve
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per solve (solve() ignores it)
args = {"theta": 0.1}  # shared by every solve (in_axes=None below)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))  # (num_devices, 1) mesh: only axis 0 is split
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()  # sharding that places a full copy on every device

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])  # NOTE(review): reshape fails unless batch_size % num_devices == 0 (e.g. 8 cores)
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])  # leading axis = device axis for pmap

args_shard = eqx.filter_shard(args, replicated)  # NOTE(review): args holds only a Python float — confirm filter_shard leaves it as-is
x, y = eqx.filter_shard((inits, keys), sharding)  # split the batch axis across devices
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
print('shard')
_ = fn(x, y, args_shard).block_until_ready()
print('pmap')
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in <cell line: 59>()
     57 fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
     58 print('shard')
---> 59 _ = fn(x, y, args_shard).block_until_ready()
     60 print('pmap')
     61 _ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

    [... skipping hidden 20 frame]

11 frames
[<ipython-input-14-d7a89459f8a5>](https://localhost:8080/#) in solve(init, key, args)
     27     ts = jnp.linspace(t0, t1, 100)
     28     saving = SaveAt(ts=ts)
---> 29     sol = diffeqsolve(
     30         terms,
     31         solver,

    [... skipping hidden 15 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump)
    914     #
    915 
--> 916     final_state, aux_stats = adjoint.loop(
    917         args=args,
    918         terms=terms,

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_adjoint.py](https://localhost:8080/#) in loop(***failed resolving arguments***)
    286             )
    287             msg = None
--> 288         final_state = self._loop(
    289             terms=terms,
    290             saveat=saveat,

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, inner_while_loop, outer_while_loop)
    437     static_made_jump = init_state.made_jump
    438     static_result = init_state.result
--> 439     _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
    440     if traced_jump:
    441         static_made_jump = None

    [... skipping hidden 14 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_integrate.py](https://localhost:8080/#) in body_fun_aux(state)
    238         #
    239 
--> 240         (y, y_error, dense_info, solver_state, solver_result) = solver.step(
    241             terms,
    242             state.tprev,

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/diffrax/_solver/runge_kutta.py](https://localhost:8080/#) in step(***failed resolving arguments***)
   1147         #     "triangular computations" (every stage depends on all previous stages)
   1148         #     without spurious copies.
-> 1149         final_val = eqxi.while_loop(
   1150             cond_stage,
   1151             rk_stage,

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/loop.py](https://localhost:8080/#) in while_loop(***failed resolving arguments***)
    105     elif kind == "checkpointed":
    106         del kind, base
--> 107         return checkpointed_while_loop(
    108             cond_fun,
    109             body_fun,

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in checkpointed_while_loop(***failed resolving arguments***)
    247     body_fun_ = filter_closure_convert(body_fun_, init_val_)
    248     vjp_arg = (init_val_, body_fun_)
--> 249     final_val_ = _checkpointed_while_loop(
    250         vjp_arg, cond_fun_, checkpoints, buffers_, max_steps
    251     )

    [... skipping hidden 8 frame]

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in _checkpointed_while_loop(***failed resolving arguments***)
    268     _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
    269     while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
--> 270     return while_loop(cond_fun, _body_fun, init_val)
    271 
    272 

[/usr/lib/python3.10/contextlib.py](https://localhost:8080/#) in inner(*args, **kwds)
     77         def inner(*args, **kwds):
     78             with self._recreate_cm():
---> 79                 return func(*args, **kwds)
     80         return inner
     81 

    [... skipping hidden 9 frame]

[/usr/local/lib/python3.10/dist-packages/equinox/internal/_loop/checkpointed.py](https://localhost:8080/#) in <lambda>(x)
    266     del checkpoints, buffers, max_steps
    267     init_val, body_fun = vjp_arg
--> 268     _body_fun = lambda x: body_fun(x)  # hashable wrapper; JAX issue #13554
    269     while_loop = jax.named_call(lax.while_loop, name="checkpointed-no-vjp")
    270     return while_loop(cond_fun, _body_fun, init_val)

    [... skipping hidden 1 frame]

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in eval_jaxpr(jaxpr, consts, propagate_source_info, *args)
    447   env: dict[Var, Any] = {}
    448   map(write, jaxpr.constvars, consts)
--> 449   map(write, jaxpr.invars, args)
    450   lu = last_used(jaxpr)
    451   for eqn in jaxpr.eqns:

ValueError: safe_map() argument 2 is shorter than argument 1

(The code above is taken from https://github.com/patrick-kidger/diffrax/issues/407.)

lockwo avatar May 07 '24 20:05 lockwo

It's possible this is an error in Equinox, but I wasn't able to replicate it exactly without diffrax. For example, the following works fine:

import os
import multiprocessing

# Expose one host (CPU) device per core. This is set before `import jax` below,
# as XLA_FLAGS generally must be.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# Make equinox runtime checks produce NaNs instead of raising.
os.environ['EQX_ON_ERROR'] = 'nan'
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import equinox.internal as eqxi
import functools as ft 

def f(t, y, theta):
    """Vector field dy/dt = |sin(t)| + theta * y for the diffrax-free repro."""
    drift = theta * y
    return drift + jnp.abs(jnp.sin(t))

# Wrap equinox's internal while_loop with jax.named_call so each loop carries
# its own name ("inner-loop" / "outer-loop") when staged out. This mirrors the
# inner_while_loop / outer_while_loop arguments of diffrax's loop() seen in
# the traceback above.
_inner_loop = jax.named_call(eqxi.while_loop, name="inner-loop")
_outer_loop = jax.named_call(eqxi.while_loop, name="outer-loop")

def solve(init, key):
    """Nested-while-loop stand-in for diffrax's solve, written without diffrax.

    The outer loop counts from a random start up to 5. Each outer iteration
    re-samples ``y`` and then runs an inner explicit-Euler loop (step 0.1)
    until ``y`` reaches 10. Returns the final ``y``.

    NOTE(review): ``init`` has no effect on the result.
    ``count_initial`` is drawn from [-2, 2) (randint's maxval is exclusive),
    so the outer body always runs at least once and overwrites ``y``.
    ``key`` only chooses the starting count, i.e. how many outer iterations run.
    """
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10  # keep stepping until y reaches 10

    def inner_loop_body(state):
        # One explicit Euler step of size 0.1 on dy/dt = f(t, y, theta).
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)
    
    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5
    
    def outer_loop_body(state):
        t, y, theta, count = state
        # Discard the incoming y and restart from a sample keyed by the loop count.
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        # `inner_while_loop` is assigned further down in solve(). This closure
        # looks it up when the body is traced, which happens after that assignment.
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    # NOTE(review): in the traceback above, diffrax's failing path runs
    # eqxi.while_loop with kind="checkpointed". Using "lax" here may be why
    # this repro does not trigger the error.
    inner_while_loop = ft.partial(_inner_loop, kind="lax")
    outer_while_loop = ft.partial(_outer_loop, kind="lax")
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]  # final y


batch_size = 30  # number of independent solves
inits = 0.1 * jnp.ones((batch_size, 1))  # one length-1 initial state per solve
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)  # one key per solve

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))  # only axis 0 is split across devices
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()  # NOTE(review): unused in this script

# The leading axis becomes the device axis for pmap.
# NOTE(review): these reshapes fail unless batch_size % num_devices == 0.
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

x, y = eqx.filter_shard((inits, keys), sharding)  # split the batch axis across devices
fn = eqx.filter_jit(eqx.filter_vmap(solve))
pmap_fn = eqx.filter_pmap(fn)  # pmap over devices; fn vmaps over each device's sub-batch
print('shard')
_ = fn(x, y).block_until_ready()
print('pmap')
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

lockwo avatar May 07 '24 20:05 lockwo

The JAX issue referenced by the code in the stack trace: https://github.com/google/jax/issues/13554

lockwo avatar May 07 '24 20:05 lockwo

All of the above code works fine in JAX 0.4.26, by the way. (The sharded version is still slower there, but that's a different issue.)

lockwo avatar May 07 '24 20:05 lockwo

Thanks for the report! Looks like an upstream JAX bug. I've opened https://github.com/google/jax/issues/21116.

patrick-kidger avatar May 07 '24 20:05 patrick-kidger

Great, closing! This is fixed in JAX 0.4.28.

lockwo avatar May 10 '24 06:05 lockwo