Performance of `filter_pmap`

Open • angusturner opened this issue 1 year ago • 5 comments

Hi,

I have been trying to optimize a training pipeline for TPU, and noticed an unusual performance degradation when using eqx.filter_pmap compared with jax.pmap.

For example, I am using the following dummy function, just to test the maximum throughput from CPU -> TPU.

def step(model, x, opt_state, *, key: jax.random.PRNGKey):
    # ... do nothing ...
    loss = 0.0
    return loss, model, opt_state

Training Loop 1 (with jax.pmap, discarding static model fields) (~25it/s):

step_parallel = jax.pmap(step, axis_name='devices')
model_weights, _ = eqx.partition(model, eqx.is_array)
for x in tqdm(loader):
    x = shard_over_devices(x)
    loss, model_weights, opt_state = step_parallel(model_weights, x, opt_state)

Training Loop 2 (with eqx.filter_pmap on full model) (~8it/s):

step_parallel = eqx.filter_pmap(step, axis_name='devices')
for x in tqdm(loader):
    x = shard_over_devices(x)
    loss, model, opt_state = step_parallel(model, x, opt_state)

Update:

If I refactor my step function into:

def step(params, static, ...):
    model = eqx.combine(params, static)
    ...
step_parallel = jax.pmap(step, static_broadcasted_argnums=1, axis_name='devices')

I still get the faster speed of 25it/s. As soon as I switch to eqx.filter_pmap it takes a performance hit, and Equinox appears to be the bottleneck 🤔 Do I need to tell eqx.filter_pmap which args are static?

Another Update:

I thought this might somehow be a quirk of asynchronous dispatch, but after building out the whole pipeline I do indeed get a speedup if I use jax.pmap and handle the filtering myself. If I have time I'll try to write a minimum-reproducible example.

angusturner avatar Sep 07 '22 04:09 angusturner

So filter_pmap definitely has a small amount of extra overhead, above that provided by manually combining jax.pmap and eqx.{partition,combine}. This is just because filter_pmap is more general, so it needs to do a bit more work to resolve all the options that have been passed to it.

My experience so far has been that this doesn't matter in practice: when doing distributed training then these overheads are dwarfed by the actual computation. At least in my own use cases.

Of course, your particular example is the opposite: the computation is a no-op, so you're maximising the extent to which overhead can affect you.

Are you finding that this overhead is affecting your own practical use-cases?

patrick-kidger avatar Sep 07 '22 17:09 patrick-kidger

Hi, thanks for getting back to me! (Feel like I'm opening a new question every day at this point 😅).

I used the no-op as an illustrative example and to diagnose the perf limits. But I am finding this matters for my actual experiment as well.

My model is a fairly standard transformer, training on a TPU v3-8. For certain configurations, it seems like I can get almost a 2x speed up by using jax.pmap. That is, I can feed the transformer 20 batches a second instead of 10 in the train loop.
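To put the throughput figures quoted in this thread side by side, here is a rough back-of-the-envelope sketch. It converts iterations per second into wall-clock milliseconds per iteration and takes the difference. Reading that difference as per-call host overhead is an assumption: it only holds if the loop is limited by the host-side call rather than by data loading or device compute. The helper name `per_iteration_ms` is made up for illustration.

```python
def per_iteration_ms(iterations_per_second):
    # Wall-clock time spent on one loop iteration, in milliseconds.
    return 1000.0 / iterations_per_second


# No-op step from the first post: ~25 it/s (jax.pmap) vs ~8 it/s (eqx.filter_pmap).
noop_gap_ms = per_iteration_ms(8) - per_iteration_ms(25)

# Transformer run: ~20 batches/s (jax.pmap) vs ~10 batches/s (eqx.filter_pmap).
transformer_gap_ms = per_iteration_ms(10) - per_iteration_ms(20)

print(f"no-op gap:       {noop_gap_ms:.0f} ms per iteration")
print(f"transformer gap: {transformer_gap_ms:.0f} ms per iteration")
```

Both gaps come out in the tens of milliseconds per call, which is why the difference is visible even with a real model on fast hardware.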

angusturner avatar Sep 07 '22 22:09 angusturner

Interesting! Yep, that's definitely not desired -- it sounds like your hardware might just be beefy enough that you're able to do the whole forward evaluation in a time that's comparable to the overhead.

So the forward pass through filter_pmap is here:

https://github.com/patrick-kidger/equinox/blob/5c6a43fa7ff6fc894f5e3e6419b3c0fdcae04bab/equinox/vmap_pmap.py#L455

Looking at this, the two obvious sources of overhead are the signature-binding here and the logic needed to handle the choice of possible mapped-over dimensions here. The latter already has some measure of caching here.
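As a rough illustration of what signature-binding costs per call, here is a minimal standard-library sketch. It is not the Equinox implementation: `direct_wrapper`, `binding_wrapper` and the no-op `step` are hypothetical stand-ins, and the assumption is simply that a general-purpose wrapper calls `inspect.Signature.bind` on every invocation while a fast path forwards positional arguments untouched.

```python
import inspect
import timeit


def step(model, x, opt_state):
    # Stand-in for the no-op training step; does no real work.
    return 0.0, model, opt_state


def direct_wrapper(fn):
    # Fast path: forward positional arguments untouched.
    def wrapped(*args):
        return fn(*args)

    return wrapped


def binding_wrapper(fn):
    # General path: resolve every call against the signature first.
    signature = inspect.signature(fn)  # computed once, at wrap time

    def wrapped(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        return fn(*bound.args, **bound.kwargs)

    return wrapped


fast = direct_wrapper(step)
general = binding_wrapper(step)
call_args = ("model", "x", "opt_state")

# For purely positional calls, binding recovers exactly the same arguments.
bound = inspect.signature(step).bind(*call_args)
same_args = bound.args == call_args and bound.kwargs == {}

n_calls = 20_000
t_fast = timeit.timeit(lambda: fast(*call_args), number=n_calls)
t_general = timeit.timeit(lambda: general(*call_args), number=n_calls)
print(f"direct:  {t_fast / n_calls * 1e6:.2f} us per call")
print(f"binding: {t_general / n_calls * 1e6:.2f} us per call")
```

The absolute numbers depend on the machine, so treat them only as a way to compare the two paths against each other.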

If you're comfortable playing with the internals of Equinox a bit, I'd quite like to know how your performance benchmarks change as we bake-in fast-paths for your particular workload.

Assuming that your annotated function: (a) has no keyword-only arguments; (b) is not called with keyword arguments; (c) you are using filter_pmap with the default filtering choices (in particular having the in and out axes be 0), then can you try:

  • Deleting this line, and replacing bound.args with just args below. (Double-check that bound.kwargs is just {}, and then just hardcode it to this value.)

  • Hardcoding max_out_size = 1 here, and deleting the call to _get_max_out_size.

  • Checking that the cache here is actually used, and that each call to eqx.filter_pmap isn't missing the cache for some reason. (lru_cache has some methods you can call to examine this.)

In each case, test each change independently of the others.
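For the cache check in the last bullet, functions wrapped with `functools.lru_cache` expose `cache_info()` and `cache_clear()`. The sketch below uses hypothetical stand-in functions (`resolve_axes`, `resolve_spec`) rather than anything from Equinox. It shows what a healthy cache looks like, and one way a cache can miss on every call: a key object with the default identity-based hash that is rebuilt each time.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def resolve_axes(in_axes, out_axes):
    # Hypothetical stand-in for an expensive, cacheable resolution step.
    return (in_axes, out_axes)


for _ in range(5):
    resolve_axes(0, 0)  # same hashable key every time
info = resolve_axes.cache_info()
print(info)  # one miss on the first call, hits afterwards


class Spec:
    # No __eq__/__hash__ defined, so every new instance is a distinct cache key.
    def __init__(self, axis):
        self.axis = axis


@lru_cache(maxsize=None)
def resolve_spec(spec):
    return spec.axis


for _ in range(5):
    resolve_spec(Spec(0))  # a fresh object per call defeats the cache
miss_info = resolve_spec.cache_info()
print(miss_info)  # misses grow with every call, hits stay at zero
```

If the miss count keeps growing across training steps, something in the cache key is probably being rebuilt on every call.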

patrick-kidger avatar Sep 08 '22 00:09 patrick-kidger

Oh also - trying a profiler like py-spy may be interesting as well. (Rather than just me speculating about which bit is slow.)
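py-spy is a separate sampling profiler that is run from the command line. As a stand-in that needs nothing beyond the standard library, here is a minimal sketch using `cProfile`, which is a different (deterministic) profiler but answers the same question: where does the host-side time per step go? The functions `overhead`, `step` and `train_loop` are hypothetical placeholders for the real training loop.

```python
import cProfile
import io
import pstats


def overhead(args):
    # Hypothetical stand-in for per-call wrapper bookkeeping.
    return sorted(args)


def step(args):
    overhead(args)
    return 0.0


def train_loop(n_steps):
    args = list(range(100))
    for _ in range(n_steps):
        step(args)


profiler = cProfile.Profile()
profiler.enable()
train_loop(10_000)
profiler.disable()

# Collect the report in a string, sorted by cumulative time per function.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative").print_stats(10)
report = buffer.getvalue()
print(report)
```

In a real run, the functions at the top of the cumulative-time listing are the ones worth comparing between the jax.pmap and eqx.filter_pmap versions of the loop.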

patrick-kidger avatar Sep 08 '22 06:09 patrick-kidger

Returning to this: Equinox version 0.10.0 should include some performance improvements for eqx.filter_pmap. I'd be curious to know whether this helps out here.

patrick-kidger avatar Feb 28 '23 03:02 patrick-kidger