Performance of `filter_pmap`
Hi,
I have been trying to optimize a training pipeline for TPU, and noticed an unusual performance degradation when using `eqx.filter_pmap` compared with `jax.pmap`.
For example, I am using the following dummy function, just to test the maximum throughput from CPU -> TPU.
```python
def step(model, x, opt_state, *, key: jax.random.PRNGKey):
    # ... do nothing ...
    loss = 0.0
    return loss, model, opt_state
```
Training Loop 1 (with `jax.pmap`, discarding static model fields) (~25 it/s):

```python
step_parallel = jax.pmap(step, axis_name='devices')
model_weights, _ = eqx.partition(model, eqx.is_array)

for x in tqdm(loader):
    x = shard_over_devices(x)
    loss, model_weights, opt_state = step_parallel(model_weights, x, opt_state)
```
Training Loop 2 (with `eqx.filter_pmap` on the full model) (~8 it/s):

```python
step_parallel = eqx.filter_pmap(step, axis_name='devices')

for x in tqdm(loader):
    x = shard_over_devices(x)
    loss, model, opt_state = step_parallel(model, x, opt_state)
```
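(For reference, `shard_over_devices` is just a small helper that splits the leading batch axis across the local devices; a minimal sketch of such a helper, assuming the batch size divides evenly by the device count, would be:)

```python
import jax
import numpy as np

def shard_over_devices(batch):
    # Reshape each leaf from (batch, ...) to (num_devices, batch // num_devices, ...),
    # so that jax.pmap / eqx.filter_pmap can map over the leading device axis.
    num_devices = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: np.reshape(x, (num_devices, -1) + x.shape[1:]), batch
    )
```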
Update:
If I refactor my step function into:
```python
def step(params, static, ...):
    model = eqx.combine(params, static)
    ...

step_parallel = jax.pmap(step, static_broadcasted_argnums=1, axis_name='devices')
```
I still get the faster speed of ~25 it/s. As soon as I switch to `eqx.filter_pmap` it takes a performance hit, and Equinox appears to be the bottleneck. Do I need to tell `eqx.filter_pmap` which arguments are static?
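For completeness, the full refactored training loop I'm describing looks roughly like this (just a sketch, reusing the dummy no-op step body from above, with the unused `key` argument dropped):

```python
def step(params, static, x, opt_state):
    model = eqx.combine(params, static)
    # ... do nothing ...
    loss = 0.0
    return loss, params, opt_state

step_parallel = jax.pmap(step, static_broadcasted_argnums=1, axis_name='devices')
params, static = eqx.partition(model, eqx.is_array)

for x in tqdm(loader):
    x = shard_over_devices(x)
    loss, params, opt_state = step_parallel(params, static, x, opt_state)
```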
Another Update:
I thought this might somehow be a quirk of asynchronous dispatch, but after building out the whole pipeline I do indeed get a speedup if I use `jax.pmap` and handle the filtering myself. If I have time I'll try to write a minimal reproducible example.
So `filter_pmap` definitely has a small amount of extra overhead, above that of manually combining `jax.pmap` and `eqx.{partition,combine}`. This is just because `filter_pmap` is more general, so it needs to do a bit more work to resolve all the options that have been passed to it.
My experience so far has been that this doesn't matter in practice: when doing distributed training, these overheads are dwarfed by the actual computation. At least in my own use cases.
Of course, your particular example is the opposite: the computation is a no-op, so you're maximising the extent to which overhead can affect you.
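(If you want to quantify just the per-call dispatch overhead, rather than end-to-end throughput, one rough way is to time an already-compiled trivial function and block on its result. A sketch, with an arbitrary shape and a trivial `step`:)

```python
import timeit

import equinox as eqx
import jax
import jax.numpy as jnp

def step(x):
    # Trivial body, so the timings below are dominated by dispatch overhead.
    return x

n = jax.local_device_count()
x = jnp.zeros((n, 32))

pmapped = jax.pmap(step, axis_name="devices")
filtered = eqx.filter_pmap(step, axis_name="devices")

# Warm up (compile) both before timing.
pmapped(x).block_until_ready()
filtered(x).block_until_ready()

print("jax.pmap        :", timeit.timeit(lambda: pmapped(x).block_until_ready(), number=1000))
print("eqx.filter_pmap :", timeit.timeit(lambda: filtered(x).block_until_ready(), number=1000))
```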
Are you finding that this overhead is affecting your own practical use-cases?
Hi, thanks for getting back to me! (Feel like I'm opening a new question every day at this point.)
I used the no-op as an illustrative example and to diagnose the perf limits. But I am finding this matters for my actual experiment as well.
My model is a fairly standard transformer, training on a TPU v3-8. For certain configurations, it seems like I can get almost a 2x speedup by using `jax.pmap`. That is, I can feed the transformer 20 batches a second instead of 10 in the training loop.
Interesting! Yep, that's definitely not desired -- it sounds like your hardware might just be beefy enough that you're able to do the whole forward evaluation in a time that's comparable to the overhead.
So the forward pass through `filter_pmap` is here: https://github.com/patrick-kidger/equinox/blob/5c6a43fa7ff6fc894f5e3e6419b3c0fdcae04bab/equinox/vmap_pmap.py#L455
Looking at this, the two obvious sources of overhead are the signature-binding here and the logic needed to handle the choice of possible mapped-over dimensions here. The latter already has some measure of caching here.
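(As a toy illustration of the first of these -- nothing Equinox-specific, just the raw cost of `inspect.Signature.bind` on each call:)

```python
import inspect
import timeit

def step(model, x, opt_state, *, key=None):
    return 0.0

sig = inspect.signature(step)
args = (None, None, None)

# Re-binding the signature on every call, versus just forwarding the args tuple.
print("bind   :", timeit.timeit(lambda: sig.bind(*args), number=100_000))
print("no bind:", timeit.timeit(lambda: args, number=100_000))
```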
If you're comfortable playing with the internals of Equinox a bit, I'd quite like to know how your performance benchmarks change as we bake in fast paths for your particular workload.
Assuming that your annotated function (a) has no keyword-only arguments; (b) is not called with keyword arguments; and (c) you are using `filter_pmap` with the default filtering choices (in particular having the in and out axes be `0`), then can you try:
- Deleting this line, and replacing `bound.args` with just `args` below. (Double-check that `bound.kwargs` is just `{}`, and then just hardcode this value.)
- Hardcoding `max_out_size = 1` here, and deleting the call to `_get_max_out_size`.
- Checking that the cache here is actually used, and that each call to `eqx.filter_pmap` isn't missing the cache for some reason. (`lru_cache` has some methods you can call to examine this; see the sketch below this list.)
In each case testing each change independently of the others.
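(For the last point: `functools.lru_cache` exposes `cache_info()`, so you can check the hit/miss counts directly. On a toy cached function it looks like this -- in Equinox's case you would call it on the internal cached helper instead:)

```python
import functools

@functools.lru_cache(maxsize=None)
def resolve(spec):
    return spec

resolve("a")
resolve("a")
resolve("b")

# Prints e.g. CacheInfo(hits=1, misses=2, maxsize=None, currsize=2)
print(resolve.cache_info())
```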
Oh also - trying a profiler like py-spy may be interesting as well. (Rather than just me speculating about which bit is slow.)
Returning to this: Equinox version 0.10.0 should include some performance improvements for `eqx.filter_pmap`. I'd be curious to know whether this helps out here.