Different runtimes for `pmap` vs `pjit`
Hi,
I have a simple script below that compares runtimes for pmap vs. pjit. I expected the runtime for pjit with full data parallelism to match pmap's, since they are functionally equivalent here (shard the data, replicate the model), but pjit is around 1.5x-2x slower in my experiments.
# script.py
import argparse
import timeit

import jax
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', type=str, choices=['pmap', 'pjit'], default='pmap')
args = parser.parse_args()

# Init data
x = np.random.randn(32, 1024).astype(np.float32)
W = np.random.randn(1024, 8).astype(np.float32)

def step(x, W):
    return jax.lax.dot(x, W)

# Compute pmap or pjit functions
# Preload batch data and model parameters onto the devices as ShardedDeviceArrays
if args.mode == 'pmap':
    p_step = jax.pmap(step, axis_name='batch')
    x = np.reshape(x, (jax.local_device_count(), -1, x.shape[1]))
    # Gets correct device order that matches pmap
    devices = jax.lib.xla_bridge.get_backend().get_default_device_assignment(jax.device_count())
    x = jax.device_put_sharded(list(x), devices)
    W = jax.device_put_replicated(W, devices)
else:
    mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(), 1), ['dp', 'mp'])
    jax.experimental.maps.thread_resources.env = (
        jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
    )
    p_step = pjit(step, in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=P('dp'))
    # Map batch and weights to devices
    p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), P('mp', None)), out_axis_resources=(P('dp'), P('mp', None)))
    x, W = p_init(x, W)

# Warmup for initial compilation
p_step(x, W).block_until_ready()

# Time
iterations = 1000
avg = timeit.timeit(lambda: p_step(x, W).block_until_ready(), number=iterations) / iterations
print('Estimated Time:', avg, 'per itr')
I see the following outputs:
wilson@t1v-n-588e2d9f-w-0:~$ python3 script.py -m pmap
Estimated Time: 0.0006339213080937043 per itr
wilson@t1v-n-588e2d9f-w-0:~$ python3 script.py -m pjit
Estimated Time: 0.0009550822210730985 per itr
Note that there is some noise in the measurements, but even when I turn up the number of iterations there is a substantial difference, and even more so if I run larger models (e.g. an MLP or a transformer vs. just a linear layer).
All of these commands are done on a v3-8 TPU instance, with jax==0.3.14, jaxlib==0.3.14, libtpu-nightly==0.1.dev20220627.
Any help or insight on this would be appreciated! (Or a pointer to any issue in my code.)
If I remove the model-parallelism component in the pjit portion of the code, i.e. use only a ['dp'] mesh axis:
mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(),), ['dp'])
jax.experimental.maps.thread_resources.env = (
    jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
)
p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))
# Map batch and weights to devices
p_init = pjit(lambda x, W: (x, W), in_axis_resources=(P('dp'), None), out_axis_resources=(P('dp'), None))
x, W = p_init(x, W)
It's slightly faster, but still slower than pmap:
wilson@t1v-n-588e2d9f-w-0:~$ python3 test/script.py -m pjit
Estimated Time: 0.000860201039002277 per itr
Is pjit doing some extra, unnecessary computation that pmap isn't?
I'd be interested in knowing the answer to this as well, as I'm observing the same thing in my code.
This computation is actually quite fast, so I imagine the difference might be due to the pmap dispatch path being more heavily optimized (it's implemented in C++) than pjit's dispatch path (which is written entirely in Python).
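One rough way to check the dispatch-overhead theory (a sketch on top of script.py above, not something that was run in this thread): compare the cost of merely enqueueing the computation against the fully blocked end-to-end time, or capture a profiler trace and look at host vs. device time in TensorBoard.
# dispatch_probe.py (sketch; assumes p_step, x, W are set up as in script.py above)
import timeit
import jax

p_step(x, W).block_until_ready()  # warm up / trigger compilation

n = 1000
# Without block_until_ready the call returns as soon as the work is enqueued,
# so this mostly measures the host-side dispatch path. Treat it as a rough
# heuristic only: once the device queue fills, enqueueing can stall as well.
dispatch_only = timeit.timeit(lambda: p_step(x, W), number=n) / n
end_to_end = timeit.timeit(lambda: p_step(x, W).block_until_ready(), number=n) / n
print('approx. dispatch-only:', dispatch_only, 'per itr')
print('end-to-end:           ', end_to_end, 'per itr')

# For a finer breakdown, capture a trace and inspect it in TensorBoard:
with jax.profiler.trace('/tmp/jax-trace'):
    for _ in range(100):
        p_step(x, W).block_until_ready()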
I tried running similar code on some larger models and see similar effects. This is also done with only data parallelism (no model axis in the mesh).
Code is run on a v3-8 TPU-VM with jax==0.3.15, jaxlib==0.3.15, libtpu-nightly==0.1.dev20220722, flax==0.5.3.
Transformer 1: 1.6M params, Transformer(hidden_dim=128, num_heads=4, num_layers=8) (1.4x slower)
pmap: Estimated Time: 0.00693524660900016 per itr
pjit: Estimated Time: 0.010101222762999896 per itr
Transformer 2: 25M params, Transformer(hidden_dim=512, num_heads=8, num_layers=8) (1.2x slower)
pmap: Estimated Time: 0.015546864855000194 per itr
pjit: Estimated Time: 0.01896321659 per itr
Transformer 3: 100M params, Transformer(hidden_dim=1024, num_heads=16, num_layers=8) (1.08x slower)
pmap: Estimated Time: 0.037007930983999814 per itr
pjit: Estimated Time: 0.04010667563000016 per itr
There still seems to be a difference, and the relative gap decreases as I increase the model size (though it stays well above the noise level).
I tried a different model similar to the one I'm actually using (for videos: frame-wise encoding/decoding with a transformer over space-time). For some reason, under an identical setup, this model is dramatically slower with pjit than with pmap (about 7x):
pmap: Estimated Time: 0.042936902346999886 per itr
pjit: Estimated Time: 0.3086913418910008 per itr
Is there any reason why the model architecture could cause such a large difference in performance?
The code for both models is below (script.py corresponds to the Transformer, and script2.py corresponds to the other model):
# script.py
import timeit
import argparse

import jax
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
import numpy as np
import flax.linen as nn


class Transformer(nn.Module):
    hidden_dim: int
    num_heads: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.LayerNorm()(x)
        for _ in range(self.num_layers):
            x = TransformerBlock(self.hidden_dim, self.num_heads)(x)
        return x


class TransformerBlock(nn.Module):
    hidden_dim: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        h = nn.LayerNorm()(x)
        h = nn.SelfAttention(num_heads=self.num_heads)(x)
        x = x + h

        h = nn.LayerNorm()(x)
        h = nn.Sequential([
            nn.Dense(4 * self.hidden_dim),
            nn.gelu,
            nn.Dense(self.hidden_dim)
        ])(h)
        x = x + h
        return x


def print_model_size(params, name=''):
    model_params_size = jax.tree_util.tree_map(lambda x: x.size, params)
    total_params_size = sum(jax.tree_util.tree_flatten(model_params_size)[0])
    print('model parameter count:', total_params_size)


parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', type=str, choices=['pmap', 'pjit'], default='pmap')
args = parser.parse_args()

# Init data
x = np.random.randn(32, 1024, 64).astype(np.float32)
model = Transformer(hidden_dim=128, num_heads=4, num_layers=8)
variables = model.init(rngs=jax.random.PRNGKey(0), x=x)
print_model_size(variables)

def step(x, variables):
    return model.apply(variables, x)

# Compute pmap or pjit functions
# Preload batch data and model parameters onto the devices as ShardedDeviceArrays
if args.mode == 'pmap':
    p_step = jax.pmap(step, axis_name='batch')
    x = np.reshape(x, (jax.local_device_count(), -1, *x.shape[1:]))
    # Gets correct device order that matches pmap
    devices = jax.lib.xla_bridge.get_backend().get_default_device_assignment(jax.device_count())
    x = jax.device_put_sharded(list(x), devices)
    variables = jax.device_put_replicated(variables, devices)
else:
    mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(),), ['dp'])
    jax.experimental.maps.thread_resources.env = (
        jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
    )
    p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))
    # Map batch and weights to devices
    p_init = pjit(lambda x, variables: (x, variables), in_axis_resources=(P('dp'), None), out_axis_resources=(P('dp'), None))
    x, variables = p_init(x, variables)

# Warmup for initial compilation
p_step(x, variables).block_until_ready()

# Time
iterations = 1000
avg = timeit.timeit(lambda: p_step(x, variables).block_until_ready(), number=iterations) / iterations
print('Estimated Time:', avg, 'per itr')
# script2.py
from typing import Tuple, Any
import timeit
import argparse

import jax
from jax.experimental.pjit import pjit
from jax.experimental.maps import Mesh
from jax.experimental import PartitionSpec as P
import numpy as np
import flax.linen as nn


class Model(nn.Module):
    enc_args: Any
    tfm_args: Any
    dec_args: Any

    @nn.compact
    def __call__(self, x):
        x = jax.vmap(Encoder(**self.enc_args), 1, 1)(x)
        old_shape = x.shape[1:-1]
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        x = Transformer(**self.tfm_args)(x)
        x = x.reshape(x.shape[0], *old_shape, x.shape[-1])
        x = jax.vmap(Decoder(**self.dec_args), 1, 1)(x)
        return x


def block(x, depth):
    skip = x
    if skip.shape[-1] != depth:
        skip = nn.Conv(depth, [1, 1], use_bias=False)(skip)
    x = nn.Sequential([
        nn.GroupNorm(),
        nn.elu,
        nn.Conv(depth, [3, 3]),
        nn.GroupNorm(),
        nn.elu,
        nn.Conv(depth, [3, 3])
    ])(x)
    return skip + 0.1 * x


class Encoder(nn.Module):
    depths: Tuple
    blocks: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.depths[0], [3, 3])(x)
        for i in range(1, len(self.depths)):
            x = nn.avg_pool(x, (2, 2), strides=(2, 2))
            for _ in range(self.blocks):
                x = block(x, self.depths[i])
        return x


class Decoder(nn.Module):
    depths: Tuple
    blocks: int

    @nn.compact
    def __call__(self, x):
        for i in range(len(self.depths) - 1):
            for _ in range(self.blocks):
                x = block(x, self.depths[i])
            x = jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
                                 method='nearest')
        x = nn.Conv(3, [3, 3])(x)
        return x


class Transformer(nn.Module):
    hidden_dim: int
    num_heads: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.LayerNorm()(x)
        for _ in range(self.num_layers):
            x = TransformerBlock(self.hidden_dim, self.num_heads)(x)
        return x


class TransformerBlock(nn.Module):
    hidden_dim: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        h = nn.LayerNorm()(x)
        h = nn.SelfAttention(num_heads=self.num_heads)(x)
        x = x + h

        h = nn.LayerNorm()(x)
        h = nn.Sequential([
            nn.Dense(4 * self.hidden_dim),
            nn.gelu,
            nn.Dense(self.hidden_dim)
        ])(h)
        x = x + h
        return x


def print_model_size(params, name=''):
    model_params_size = jax.tree_util.tree_map(lambda x: x.size, params)
    total_params_size = sum(jax.tree_util.tree_flatten(model_params_size)[0])
    print('model parameter count:', total_params_size)


parser = argparse.ArgumentParser()
parser.add_argument('-m', '--mode', type=str, choices=['pmap', 'pjit'], default='pmap')
args = parser.parse_args()

# Init data
x = np.random.randn(32, 100, 16, 16, 3).astype(np.float32)
model = Model(enc_args=dict(depths=[64, 128, 256], blocks=2),
              tfm_args=dict(hidden_dim=512, num_heads=8, num_layers=8),
              dec_args=dict(depths=[256, 128, 64], blocks=2))
variables = model.init(rngs=jax.random.PRNGKey(0), x=x)
print_model_size(variables)

def step(x, variables):
    return model.apply(variables, x)

# Compute pmap or pjit functions
# Preload batch data and model parameters onto the devices as ShardedDeviceArrays
if args.mode == 'pmap':
    p_step = jax.pmap(step, axis_name='batch')
    x = np.reshape(x, (jax.local_device_count(), -1, *x.shape[1:]))
    # Gets correct device order that matches pmap
    devices = jax.lib.xla_bridge.get_backend().get_default_device_assignment(jax.device_count())
    x = jax.device_put_sharded(list(x), devices)
    variables = jax.device_put_replicated(variables, devices)
else:
    mesh = Mesh(np.asarray(jax.devices(), dtype=object).reshape(jax.local_device_count(),), ['dp'])
    jax.experimental.maps.thread_resources.env = (
        jax.experimental.maps.ResourceEnv(physical_mesh=mesh, loops=())
    )
    p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))
    # Map batch and weights to devices
    p_init = pjit(lambda x, variables: (x, variables), in_axis_resources=(P('dp'), None), out_axis_resources=(P('dp'), None))
    x, variables = p_init(x, variables)

# Warmup for initial compilation
p_step(x, variables).block_until_ready()

# Time
iterations = 1000
avg = timeit.timeit(lambda: p_step(x, variables).block_until_ready(), number=iterations) / iterations
print('Estimated Time:', avg, 'per itr')
I looked into it, and I think it's basically that the SPMD partitioner ends up failing to propagate the data-parallel sharding through the full HLO. The HLO module does contain some while loops that might be confusing it. The SPMD partitioner is heuristic-based, so it might not get everything right. This is why manual partitioning with pmap or xmap can sometimes lead to better results: they give you explicit guarantees.
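For anyone who wants to verify this on their own model, one option (a sketch, not part of the original scripts; the dump path is illustrative) is to have XLA dump the optimized, SPMD-partitioned HLO and check which ops ended up with the expected sharding annotations:
# hlo_dump_sketch.py (sketch)
import os
# Must be set before JAX initializes its backends.
os.environ['XLA_FLAGS'] = os.environ.get('XLA_FLAGS', '') + ' --xla_dump_to=/tmp/hlo_dump'

import jax
# ... build p_step, x, variables exactly as in script2.py above, then run once:
# p_step(x, variables).block_until_ready()
#
# /tmp/hlo_dump should then contain *after_optimizations*.txt files for the
# partitioned module; the sharding={...} attributes on individual HLO ops
# (including the while loops) show where the data-parallel sharding did or
# did not propagate.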
Thanks for looking into it! Is there a good way to prevent this issue code-wise, e.g. by structuring the architecture differently or by adding constraints that help the partitioner shard things the way we'd like (assuming we know the desired data/model partitioning scheme ahead of time)?
Otherwise, is the best option to switch to xmap to get the right mix of data sharding and model parallelism?
You could try sprinkling some with_sharding_constraint calls over your code; it might help the SPMD partitioner do the right thing. But if you know exactly what sharding you want and prefer explicit control, then xmap sounds like a better way to go.
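For example, a minimal sketch on top of the data-parallel setup above (not something from the original scripts) is to constrain the activations at the boundaries of step so the partitioner keeps them sharded along the 'dp' axis:
# sharding_constraint_sketch.py (sketch; assumes the 'dp' mesh, model, and variables are set up as above)
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental import PartitionSpec as P

def step(x, variables):
    # Pin the batch dimension to the 'dp' mesh axis so the SPMD partitioner
    # keeps the activations data-parallel instead of relying on its heuristics.
    x = with_sharding_constraint(x, P('dp'))
    out = model.apply(variables, x)
    return with_sharding_constraint(out, P('dp'))

p_step = pjit(step, in_axis_resources=(P('dp'), None), out_axis_resources=P('dp'))
Constraints can also be added deeper inside the model (e.g. around the reshapes in Model.__call__) if the boundary constraints alone aren't enough.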
In the meantime, I've brought this up with the XLA team and opened an internal bug with them. Since there's nothing we can do on the JAX side to fix it, I'm going to close this issue. Feel free to reopen if you think there's still something more we could do. Thanks!
And also thanks so much for a detailed repro!