Unexpected scaling of sharding/pmap
Description
If I run a fixed-size function and scale it linearly across more and more CPUs/GPUs (running the same function on the same-sized input on each additional device), I would expect the execution time to stay roughly constant, since each device is doing the same computation. There would be some overhead from moving arrays around in memory, dispatching, and collecting results, but with arrays of O(10) floats I would expect this to be tiny. However, when I do this in JAX with pmap or sharding, on CPU or GPU, the time increases noticeably with the number of devices. Is this due to how JAX manages distribution, or am I doing something wrong?
Here is the example code:
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import os
import time
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard
devices_to_use = multiprocessing.cpu_count()
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(devices_to_use)
print(jax.devices())
def vectorized_solve(y0, v):
    return jax.vmap(solve, in_axes=(0, None))(y0, v)
def solve(init, var):
    return jnp.mean(init @ var)
n_per_device = 100
problem_size = 20
all_devices_to_use = [i for i in range(1, 7)]
all_total_traj = []
all_timings_aux = []
all_timings = []
y0s = []
vectorized_jit = jax.jit(vectorized_solve)
key = jax.random.PRNGKey(42)
reps = 500
for devices_to_use in all_devices_to_use:
    # Use jax.devices("cpu") here instead for the CPU run
    devices = jax.devices("gpu")[:devices_to_use]
    print(len(devices))
    device_mesh = mesh_utils.create_device_mesh((len(devices), 1), devices)
    sharding = jshard.PositionalSharding(devices).reshape((len(devices), 1))
    sharding_replicate = sharding.replicate()
    # Fixed work per device: the total problem grows with the device count
    n_traj_total = len(devices) * n_per_device
    y_0 = jnp.ones((n_traj_total, problem_size))
    y_0 = y_0 + jax.random.uniform(key, shape=y_0.shape)
    y0_shard = jax.device_put(y_0, sharding).block_until_ready()
    var = jax.random.uniform(key, shape=(problem_size, 5))
    var_shard = jax.device_put(var, sharding_replicate).block_until_ready()
    # Warm-up call so compilation is not included in the timings
    _ = vectorized_jit(y0_shard, var_shard).block_until_ready()
    tot = 0
    for i in range(reps):
        start_time = time.time()
        results = vectorized_jit(y0_shard, var_shard).block_until_ready()
        end_time = time.time()
        tot += (end_time - start_time)
    tot /= reps
    all_timings.append(tot)
    # Same computation with pmap: one leading-axis slice per device
    pmap = jax.pmap(vectorized_solve, in_axes=(0, None), devices=devices)
    y_0 = y_0.reshape((len(devices), n_per_device, problem_size))
    _ = pmap(y_0, var).block_until_ready()
    tot = 0
    for i in range(reps):
        start_time = time.time()
        results = pmap(y_0, var).block_until_ready()
        end_time = time.time()
        tot += (end_time - start_time)
    tot /= reps
    all_timings_aux.append(tot)
fig, axs = plt.subplots()
axs.plot(all_devices_to_use, all_timings, linestyle="--", marker="o", label="shard eval")
axs.plot(all_devices_to_use, all_timings_aux, linestyle="--", marker="o", label="pmap eval")
axs.legend()
axs.set_xlabel("Number of devices")
axs.set_ylabel("Total Time")
axs.set_yscale("log")
fig.tight_layout()
plt.show()
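GPU: [plot: total time (log scale) vs. number of devices, shard eval and pmap eval]
CPU: [plot: total time (log scale) vs. number of devices, shard eval and pmap eval]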
In my mind, these lines should be almost flat: both the arrays being returned and the arrays being sharded are very small.
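To make the per-call numbers easier to compare, here is a minimal sketch of a timing helper that reports the minimum and median alongside the mean, so that a few slow outlier calls do not dominate the result. The name `time_calls` and its parameters are hypothetical and not part of JAX; it uses only the standard library, and the lambda at the bottom is a stand-in workload.

```python
import statistics
import time


def time_calls(fn, reps=500, warmup=1):
    """Call fn() repeatedly and return per-call timing statistics in seconds.

    Hypothetical helper: fn must block until its work is finished,
    otherwise only the dispatch time is measured.
    """
    for _ in range(warmup):
        fn()  # warm-up calls are not timed (e.g. to exclude compilation)
    samples = []
    for _ in range(reps):
        t0 = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - t0)
    return {
        "min": min(samples),
        "median": statistics.median(samples),
        "mean": statistics.fmean(samples),
    }


# Stand-in workload so the sketch runs without any accelerator.
stats = time_calls(lambda: sum(range(1000)), reps=200)
print(stats)
```

With the script above, the timed call would be something like `time_calls(lambda: vectorized_jit(y0_shard, var_shard).block_until_ready())`; the `block_until_ready()` is what makes the lambda wait for the result.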
As an even more trivial example, I took the code from the docs on sharding. Running 8 arrays on 8 devices is noticeably slower than running 4 arrays on 4 devices, which I find very surprising.
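A rough sketch of that kind of comparison is below. It is not the docs example itself: the function `time_on_n_devices`, the elementwise workload, and the mesh axis name `"d"` are assumptions made for illustration, and it uses `NamedSharding` rather than the `PositionalSharding` used in the script above. It assumes 8 virtual CPU devices requested through `XLA_FLAGS`, and only times the device counts that are actually available.

```python
import os

# Assumption: ask for 8 virtual CPU devices. This has to be set before
# JAX initialises its backend; an existing XLA_FLAGS value is left alone.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def time_on_n_devices(n_devices, n_per_device=100, problem_size=20, reps=50):
    """Median seconds per call, with a fixed amount of work per device."""
    devices = np.array(jax.devices()[:n_devices])
    mesh = Mesh(devices, ("d",))
    # Rows are split across the "d" axis; columns are not split.
    x = jnp.ones((len(devices) * n_per_device, problem_size))
    x = jax.device_put(x, NamedSharding(mesh, P("d", None)))
    f = jax.jit(lambda a: jnp.mean(a * 2.0, axis=1))  # illustrative workload
    f(x).block_until_ready()  # compile outside the timed region
    samples = []
    for _ in range(reps):
        t0 = time.perf_counter()
        f(x).block_until_ready()
        samples.append(time.perf_counter() - t0)
    return float(np.median(samples))


available = len(jax.devices())
timings = {n: time_on_n_devices(n) for n in (1, 2, 4, 8) if n <= available}
for n, t in timings.items():
    print(f"{n} device(s): {t * 1e6:.1f} us per call")
```

Because the per-device work is constant, any growth in these numbers with the device count reflects per-call overhead rather than extra computation on each device.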
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (6 total, 6 local): [cuda(id=0) cuda(id=1) ... cuda(id=4) cuda(id=5)]
process_count: 1
release='5.15.0-89-generic', version='#99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023', machine='x86_64')