
Question: `filter_shard` with PartitionSpecs, or other ways to ensure batched output, as in `vmap`.

johannahaffner opened this issue on Jun 14 '24 · 3 comments

Hi again! I have yet another question, this time to do with sharding.

Essentially, I'm trying to shard to improve performance, but I would like the output to be the same as that of vmap.

Background: I currently vmap to parallelise across my data, and I want to get a single value for each input array. (I do parameter estimation for ODEs across individuals.) I noticed that this approach only ever uses about 20% of my CPU, and confirmed that it essentially runs on a single CPU core. I'd like to change that to multi-core CPU first, and then scale to GPUs.

What I did so far:

  • I tried following the data parallelism example, but did not get it to work: it seems to essentially fit the average time series and returns a single value.
  • I took a look at the implementation of eqx.filter_vmap, and I'm wondering whether it could be broadened to allow for something like the PartitionSpec options JAX has. Alternatively, should I use decorators on top of eqx.filter_shard, similar to how it is done in the JAX shard_map tutorial (a rough sketch of that pattern follows this list), or not use Equinox at all?
  • Alternatively, a naive approach would be to try a pmap-of-vmaps, but I'd like to avoid that since I'm not sure whether pmap will be deprecated.
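
To make the shard_map idea above concrete, this is roughly the pattern from the JAX tutorial that I have in mind, boiled down to a placeholder function. It is untested, and solve_one is just a stand-in for the per-individual fit, not something from the example further down.

# Rough, untested sketch of the shard_map pattern from the JAX tutorial.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # before importing jax

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((8,)), axis_names=("i",))

def solve_one(k0, y):
    # Placeholder: in reality this would be the least-squares fit for one individual.
    return k0 * jnp.mean(y)

@jax.jit
def fit_all(k0s, ys):
    # Each device receives a slice of the leading (individual) axis and vmaps over it locally.
    f = shard_map(jax.vmap(solve_one), mesh, in_specs=(P("i"), P("i")), out_specs=P("i"))
    return f(k0s, ys)

With the batched arrays from the example below, this would be called as fit_all(k0, ys).
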

Here is some code that replicates the behavior I am talking about. Apologies for length, I tried to condense it as much as I could.

import os

# Set this before importing JAX so the flag is picked up when the CPU backend is initialised.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # Use 8 CPU devices

import pytest

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard

import equinox as eqx
import diffrax as dfx
import optimistix as optx

class ToyModel(eqx.Module):
    """Toy model that integrates an ODE."""
    _term: dfx.ODETerm
    def __init__(self, ode_model):
        self._term = dfx.ODETerm(ode_model)

    def __call__(self, param):
        sol = dfx.diffeqsolve(self._term, dfx.Tsit5(), 0, 30, 0.01, jnp.array([10.]), args=param, 
            saveat=dfx.SaveAt(ts=jnp.linspace(0, 30, 50)), adjoint=dfx.DirectAdjoint(), max_steps=16**4)
        return jnp.transpose(sol.ys)  # To get the shape I'm used to

class Estimator(eqx.Module):
    _solver: optx.AbstractLeastSquaresSolver
    def __init__(self, solver):
        self._solver = solver

    @eqx.filter_jit(donate="none")  # I hope that keeps it simpler for now
    def __call__(self, param, model, data):
        args = (model, data)
        def residuals(param, args):
            model, data = args
            fit = model(param)
            return data - fit
        solution = optx.least_squares(residuals, self._solver, param, args=args)
        return solution.value

# Create the model
def dydt(t, y, k): 
    return -k * y 
model = ToyModel(dydt)

# Generate fake data
rates = jnp.arange(0.1, 0.9, 0.1)  # Makes array with eight entries
ys = eqx.filter_vmap(model)(rates)

# Create the estimator
estimator = Estimator(optx.LevenbergMarquardt(atol=1e-06, rtol=1e-03))

# Solution: current approach (uses one core of my CPU even for very large data sets)
k0 = jnp.array([0.5]*8)
fitted_ks = eqx.filter_jit(eqx.filter_vmap(estimator, in_axes=(0, None, 0)))(k0, model, ys)  # Current approach

# Try something akin to the equinox data parallelism example
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

sharded_estimator = eqx.filter_shard(estimator, replicated)
sharded_inputs = eqx.filter_shard((k0, model, ys), replicated)

with pytest.raises(ValueError):
    # This causes problems in diffrax: complaint about PyTree structure
    fitted_ks_with_sharding = sharded_estimator(*sharded_inputs)  

scalar_k0 = 0.5
sharded_inputs_with_scalar_k0 = eqx.filter_shard((scalar_k0, model, ys), replicated)
sharding_solutions = sharded_estimator(*sharded_inputs_with_scalar_k0)  # Seems to fit the average

print(sharding_solutions)

johannahaffner · Jun 14 '24 06:06

PS: a bit of a random thought that I will try later: should I shard the residuals function instead?

johannahaffner · Jun 14 '24 08:06

Ah, I think this is something JAX still needs better docs for.

In response to the various points you raise:

  • filter_vmap definitely shouldn't do anything related to sharding. We use vmap to describe the logical computation, and sharding to describe the physical computation -- these two things are deliberately kept separate. (Combining them into one is pretty much what {filter_}pmap does, and there's a reason that JAX is moving away from that!)

  • The recommended approach to this these days is to write out your logical computation as if it was all running on one big device -- using vmaps and whatever else necessary to express your computation. Then, when you call it, pass in sharded arrays to describe where you would like the computation to be performed physically.

    As such, best practice is pretty much what's currently shown in the Equinox parallelism tutorial. You say that you seem to be fitting the average time series -- this sounds like a bug in your logical computation, and as such would be unrelated to sharding (which is the physical computation). So I think tracking that down is probably the main thing you need to do.

  • In terms of your example code: I think you've forgotten to vmap your sharded estimator -- I only see you doing a vmap in the unsharded case. (The sketch at the end of this comment shows the combination I have in mind.)

  • Personally I've always found JAX's multi-CPU parallelism to be fairly dodgy. This has never been very well supported in JAX's backend, and indeed sharding is mostly used to express multi-GPU or multi-TPU parallelism. Logic bugs aside, it's entirely possible that you'll see no speedup / that it only uses one core.
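
To make the second and third bullets concrete, here's a rough, untested sketch of the combination I have in mind, reusing estimator, model, k0 and ys from your example. I've used a NamedSharding rather than PositionalSharding purely so that arrays of different rank (k0 is 1D, ys is 3D) can share a single "shard the leading batch axis" spec -- treat it as the shape of the solution rather than a drop-in fix.

# Untested sketch; `estimator`, `model`, `k0` and `ys` are as defined in your example above.
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard

import equinox as eqx

# Logical computation: one big vmap over the batch axis.
vmapped_estimator = eqx.filter_jit(eqx.filter_vmap(estimator, in_axes=(0, None, 0)))

# Physical placement: split the batch axis across devices, replicate everything else.
num_devices = len(jax.devices())
mesh = jshard.Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("batch",))
batch_sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))
replicated = jshard.NamedSharding(mesh, jshard.PartitionSpec())

k0_sharded, ys_sharded = eqx.filter_shard((k0, ys), batch_sharding)
model_replicated = eqx.filter_shard(model, replicated)

fitted_ks = vmapped_estimator(k0_sharded, model_replicated, ys_sharded)

As per my last point above, though: multi-CPU is not a strongly supported path, so don't be surprised if the wall-clock improvement is small even once this runs.
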

patrick-kidger · Jun 15 '24 11:06

Hi Patrick,

Thank you for your thoughtful response! I always learn so much from your explanations.

I'm trying to whittle it down to a smaller example. So far I've noticed that it does not raise an error if I shard only the vmapped estimator and not the inputs -- in that case the output is what I would expect. My toy model is too small to tell whether it is really running on all cores, though.
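
For the "is it really running on all cores" part, the check I have in mind is to look at the output sharding directly, roughly like this. It's just a sketch, and sharded_vmapped_estimator is a made-up name for the sharded, vmapped combination described above:

import jax

# `k0`, `model` and `ys` are from my example above; `sharded_vmapped_estimator` is a
# placeholder name for the sharded + vmapped estimator.
out = sharded_vmapped_estimator(k0, model, ys)

print(out.sharding)                       # which devices actually hold the result
jax.debug.visualize_array_sharding(out)   # layout picture; works for 1D/2D arrays, needs `rich`
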

I'm trying to get it to work on the non-toy example, but that adds quite a bit of complexity and I'm not there yet :)

johannahaffner · Jun 15 '24 18:06