Question: `filter_shard` with PartitionSpecs, or other ways to ensure batched output, as in `vmap`.
Hi again! I have yet another question, this time to do with sharding.
Essentially, I'm trying to shard to improve performance, but I would like the output to be the same as that of vmap.
Background: I currently vmap to parallelise across my data, getting a single value for each input array. (I do parameter estimation for ODEs across individuals.) I noticed that this approach only ever uses about 20% of my CPU, and confirmed that it essentially runs on a single CPU core. I'd like to change that to use multiple CPU cores, and then scale to GPUs.
What I did so far:
- I tried following the data parallelism example, but did not get it to work. It seems to essentially fit the average time series and it returns a single value.
- I took a look at the implementation of `eqx.filter_vmap` and I'm wondering if this could be broadened to allow for something like the `PartitionSpec` options JAX has, or if I should use decorators on top of `eqx.filter_shard`, similar to how it is done in the JAX shard map tutorial, or not use equinox at all.
- Alternatively, a naive approach would be to try a pmap-of-vmaps (roughly what I mean is sketched just after the code below), but I'd like to avoid that since I'm not sure whether pmap will be deprecated.
Here is some code that replicates the behavior I am talking about. Apologies for length, I tried to condense it as much as I could.
```python
import os

# Use 8 CPU devices; set before importing jax so the flag is picked up.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import pytest
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard
import equinox as eqx
import diffrax as dfx
import optimistix as optx


class ToyModel(eqx.Module):
    """Toy model that integrates an ODE."""

    _term: dfx.ODETerm

    def __init__(self, ode_model):
        self._term = dfx.ODETerm(ode_model)

    def __call__(self, param):
        sol = dfx.diffeqsolve(
            self._term, dfx.Tsit5(), 0, 30, 0.01, jnp.array([10.0]), args=param,
            saveat=dfx.SaveAt(ts=jnp.linspace(0, 30, 50)),
            adjoint=dfx.DirectAdjoint(), max_steps=16**4,
        )
        return jnp.transpose(sol.ys)  # To get the shape I'm used to


class Estimator(eqx.Module):
    _solver: optx.AbstractLeastSquaresSolver

    def __init__(self, solver):
        self._solver = solver

    @eqx.filter_jit(donate="none")  # I hope that keeps it simpler for now
    def __call__(self, param, model, data):
        args = (model, data)

        def residuals(param, args):
            model, data = args
            fit = model(param)
            return data - fit

        solution = optx.least_squares(residuals, self._solver, param, args=args)
        return solution.value


# Create the model
def dydt(t, y, k):
    return -k * y


model = ToyModel(dydt)

# Generate fake data
rates = jnp.arange(0.1, 0.9, 0.1)  # Makes an array with eight entries
ys = eqx.filter_vmap(model)(rates)

# Create the estimator
estimator = Estimator(optx.LevenbergMarquardt(atol=1e-06, rtol=1e-03))

# Solution: current approach (uses one core of my CPU even for very large data sets)
k0 = jnp.array([0.5] * 8)
fitted_ks = eqx.filter_jit(eqx.filter_vmap(estimator, in_axes=(0, None, 0)))(k0, model, ys)

# Try something akin to the equinox data parallelism example
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()
sharded_estimator = eqx.filter_shard(estimator, replicated)
sharded_inputs = eqx.filter_shard((k0, model, ys), replicated)

with pytest.raises(ValueError):
    # This causes problems in diffrax: complaint about PyTree structure
    fitted_ks_with_sharding = sharded_estimator(*sharded_inputs)

scalar_k0 = 0.5
sharded_inputs_with_scalar_k0 = eqx.filter_shard((scalar_k0, model, ys), replicated)
sharding_solutions = sharded_estimator(*sharded_inputs_with_scalar_k0)  # Seems to fit the average
sharding_solutions
```
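For completeness, this is roughly what I meant above by a pmap-of-vmaps. It's only an untested sketch that reuses the names from the snippet above, and exactly the kind of thing I'd rather avoid:

```python
# Untested sketch of the pmap-of-vmaps idea: split the batch across devices
# with pmap, then vmap over each device's chunk.
per_device = k0.shape[0] // num_devices  # assumes the batch divides evenly
k0_chunks = k0.reshape(num_devices, per_device)
ys_chunks = ys.reshape(num_devices, per_device, *ys.shape[1:])

inner = eqx.filter_vmap(lambda p, y: estimator(p, model, y))  # vmap within each device
fitted_chunks = jax.pmap(inner)(k0_chunks, ys_chunks)         # pmap across devices
fitted_ks_pmap = fitted_chunks.reshape(-1)                    # back to one flat batch
```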
PS: a bit of a random thought that I'll try later: should I shard the residuals function instead?
Ah, I think this is something JAX still needs better docs for.
In response to the various points you raise:
- `filter_vmap` definitely shouldn't do anything to do with sharding. We use vmap to describe the logical computation, and sharding to describe the physical computation -- these two things are deliberately held separate. (Combining them into one is pretty much what `{filter_}pmap` does, and there's a reason that JAX is moving away from that!)
- The recommended approach to this these days is to write out your logical computation as if it was all running on one big device -- using vmaps and whatever else necessary to express your computation. Then, when you call it, pass in sharded arrays to describe where you would like the computation to be performed physically. As such, best practice is pretty much what's currently shown in the Equinox parallelism tutorial. You say that you seem to be fitting the average time series -- this sounds like a bug in your logical computation, and as such would be unrelated to sharding (which is the physical computation). So I think tracking that down is probably the main thing you need to do.
- In terms of your example code: I think you've forgotten to vmap your sharded estimator. I only see you doing a vmap in the unsharded case. (See the sketch just after this list.)
- Personally I've always found JAX's multi-CPU parallelism to be fairly dodgy. This has never been very well supported in JAX's backend, and indeed sharding is mostly used to express multi-GPU or multi-TPU parallelism. Logic bugs aside, it's entirely possible that you'll see no speedup / that it only uses one core.
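To make the middle two points concrete, here's a rough sketch against your toy example. I've swapped your `PositionalSharding` for a `NamedSharding` over a 1D mesh (either should be fine here); the important part is that the vmap stays exactly as in your unsharded version, and the sharding only enters via `jax.device_put` on the batched inputs. I haven't run this against your setup:

```python
# The logical computation: unchanged from the unsharded case.
batched_estimator = eqx.filter_jit(eqx.filter_vmap(estimator, in_axes=(0, None, 0)))

# The physical computation: lay the batch axis out across the devices.
mesh = jshard.Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("batch",))
batch_sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("batch"))

# Only the batched arrays are sharded; `model` isn't batched, so it's passed through as-is.
k0_sharded = jax.device_put(k0, batch_sharding)
ys_sharded = jax.device_put(ys, batch_sharding)

fitted_ks_sharded = batched_estimator(k0_sharded, model, ys_sharded)
```

Note that nothing about the vmap changes: the jit simply sees sharded inputs and compiles a computation that runs across the devices holding them.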
Hi Patrick,
Thank you for your thoughtful response! I always learn so much from your explanations.
I'm trying to whittle it down to a smaller example. So far I've noticed that no error is raised if I shard only the vmapped estimator and not the inputs, and then the output is what I would expect. My toy model is too small to check whether it is really running on all cores in that case, though.
I'm trying to get it to work on the non-toy example, but that adds quite a bit of complexity and I'm not there yet :)
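In the meantime, the most I can think to do on the "is it really running on all cores" question is to inspect how the arrays end up laid out across the devices -- something like the snippet below, where `fitted` is just a placeholder for whatever my sharded call returns. That only confirms placement, not that every core is actually busy.

```python
# `fitted` stands for the output of the sharded, vmapped call.
print(fitted.sharding)                      # which device(s) hold the result
jax.debug.visualize_array_sharding(fitted)  # ASCII view; works for 1-D and 2-D arrays
```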