
Unexpected scaling of sharding/pmap

Open lockwo opened this issue 1 year ago • 0 comments

Description

If I run a fixed-size function and scale it linearly across more and more CPUs/GPUs (running the same function on the same-sized input again on each additional device), I would expect the execution time to stay roughly constant, since each device is doing the same computation. There may be some overhead from moving arrays around in memory, dispatching, and collecting results, but with arrays on the order of 10 floats I would expect that overhead to be tiny. However, when I do this in JAX with pmap or sharding, on either CPU or GPU, the time noticeably increases with more devices. Is this something to do with how JAX manages distribution, or am I doing something wrong?
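A rough way to state this expectation (my own simplified model, not something taken from the JAX docs): for n devices each running an identical, independent workload, the time per call is roughly

$$
T(n) \approx T_{\text{compute}} + T_{\text{fixed}} + n \, t_{\text{dev}}
$$

where $T_{\text{compute}}$ is the per-device computation (independent of $n$, since the devices run in parallel), $T_{\text{fixed}}$ is the per-call cost that does not depend on $n$ (Python dispatch, one final synchronization), and $t_{\text{dev}}$ is any cost paid once per participating device (launching work on it, transferring or gathering its shard). Flat curves correspond to $t_{\text{dev}} \approx 0$; a curve that rises with $n$ means $t_{\text{dev}}$ is not negligible compared with the other two terms.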

Here is the example code:

```python
import multiprocessing
import os
import time

# XLA_FLAGS must be set before JAX initializes its backend; setting it before
# `import jax` is the safest place. It only affects the CPU ("host") platform.
devices_available = multiprocessing.cpu_count()
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(devices_available)

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as jshard

PLATFORM = "gpu"  # set to "cpu" to reproduce the CPU plot
print(jax.devices(PLATFORM))


def solve(init, var):
    # init: (problem_size,), var: (problem_size, 5) -> scalar
    return jnp.mean(init @ var)


def vectorized_solve(y0, v):
    # Map solve over the leading (trajectory) axis of y0; v is shared by all rows.
    return jax.vmap(solve, in_axes=(0, None))(y0, v)


n_per_device = 100
problem_size = 20
all_devices_to_use = list(range(1, 7))
reps = 500

all_timings = []      # jit + sharding: mean seconds per call
all_timings_aux = []  # pmap: mean seconds per call

vectorized_jit = jax.jit(vectorized_solve)
key = jax.random.PRNGKey(42)

for n_devices in all_devices_to_use:
    devices = jax.devices(PLATFORM)[:n_devices]
    print(len(devices))

    # Shard the trajectory axis across the devices and replicate the second axis.
    # (NamedSharding is equivalent to the older PositionalSharding(...).reshape((n, 1)).)
    device_mesh = mesh_utils.create_device_mesh((len(devices), 1), devices)
    mesh = jshard.Mesh(device_mesh, ("traj", "feat"))
    sharding = jshard.NamedSharding(mesh, jshard.PartitionSpec("traj", None))
    sharding_replicate = jshard.NamedSharding(mesh, jshard.PartitionSpec())

    # Total work grows with the device count, so each device always gets n_per_device rows.
    n_traj_total = len(devices) * n_per_device

    y_0 = jnp.ones((n_traj_total, problem_size))
    y_0 = y_0 + jax.random.uniform(key, shape=y_0.shape)
    y0_shard = jax.device_put(y_0, sharding).block_until_ready()

    var = jax.random.uniform(key, shape=(problem_size, 5))
    var_shard = jax.device_put(var, sharding_replicate).block_until_ready()

    # jit + sharding: warm up (compile) once, then time steady-state calls.
    _ = vectorized_jit(y0_shard, var_shard).block_until_ready()
    tot = 0.0
    for _ in range(reps):
        start_time = time.perf_counter()
        results = vectorized_jit(y0_shard, var_shard).block_until_ready()
        tot += time.perf_counter() - start_time
    all_timings.append(tot / reps)

    # pmap: the leading axis of y_0 is the device axis; var is broadcast to every device.
    pmapped = jax.pmap(vectorized_solve, in_axes=(0, None), devices=devices)
    y_0 = y_0.reshape((len(devices), n_per_device, problem_size))
    _ = pmapped(y_0, var).block_until_ready()

    tot = 0.0
    for _ in range(reps):
        start_time = time.perf_counter()
        results = pmapped(y_0, var).block_until_ready()
        tot += time.perf_counter() - start_time
    all_timings_aux.append(tot / reps)

fig, axs = plt.subplots()
axs.plot(all_devices_to_use, all_timings, linestyle="--", marker="o", label="shard eval")
axs.plot(all_devices_to_use, all_timings_aux, linestyle="--", marker="o", label="pmap eval")
axs.legend()
axs.set_xlabel("Number of devices")
axs.set_ylabel("Mean time per call (s)")
axs.set_yscale("log")
fig.tight_layout()
plt.show()
```
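The hand-written timing loops above average wall-clock time over `reps` calls. A sketch of a small helper that could replace them, assuming the callable passed in blocks until its work is done (for example by calling `.block_until_ready()`). `time_call` is a hypothetical name, and it reports the median rather than the mean so that occasional scheduler hiccups do not skew the comparison between device counts.

```python
import statistics
import time
from typing import Callable


def time_call(fn: Callable[[], object], reps: int = 500, warmup: int = 5) -> float:
    """Return the median wall-clock seconds of fn() over `reps` calls.

    `fn` must block until its work is finished, e.g.
    lambda: vectorized_jit(y0_shard, var_shard).block_until_ready()
    """
    for _ in range(warmup):
        fn()  # absorb compilation and other first-call costs outside the timed loop
    samples = []
    for _ in range(reps):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```

Usage inside the loop would look like `all_timings.append(time_call(lambda: vectorized_jit(y0_shard, var_shard).block_until_ready()))`.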

GPU: [screenshot of the resulting timing plot]

CPU: [screenshot of the resulting timing plot]

I would expect these lines to be almost flat: both the arrays being sharded and the arrays being returned are very small.
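To put a number on "almost flat", a sketch that fits a straight line to the measured timings. `per_device_slope` is a hypothetical helper built on `numpy.polyfit`; the slope is the extra time per additional device, in the same units as the timings, and a slope near zero relative to the intercept means the curve is flat.

```python
import numpy as np


def per_device_slope(n_devices, timings):
    """Least-squares fit of timings ~= intercept + slope * n_devices.

    Returns (intercept, slope) in the units of `timings`.
    """
    x = np.asarray(n_devices, dtype=float)
    y = np.asarray(timings, dtype=float)
    # polyfit returns coefficients from highest degree to lowest: [slope, intercept]
    slope, intercept = np.polyfit(x, y, deg=1)
    return float(intercept), float(slope)
```

For example, `per_device_slope(all_devices_to_use, all_timings)` for the sharded curve and `per_device_slope(all_devices_to_use, all_timings_aux)` for the pmap curve.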

As an even more trivial example, I took the code from the sharding docs. Sharding 8 arrays across 8 devices is noticeably slower than sharding 4 arrays across 4 devices, which I find very surprising.

[screenshot of the timings for the sharding-docs example]
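For anyone trying to reproduce the 4-vs-8 comparison, a minimal self-contained sketch of a comparable experiment on CPU; it is not the exact code from the docs. It forces 8 host devices with `--xla_force_host_platform_device_count`, shards an `(n, 10)` array so that each device holds one row, and reports the median time per call of a trivial jitted function for n = 4 and n = 8. The function `jnp.sin(a) * 2.0`, the array shape, and `time_sharded` are illustrative assumptions.

```python
import os

# Assumption: force 8 host (CPU) devices; this must happen before JAX initializes.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def time_sharded(n_devices, reps=200):
    """Median seconds per call of a jitted elementwise function on an
    (n_devices, 10) array sharded so that each device holds one row."""
    devices = np.array(jax.devices("cpu")[:n_devices])
    mesh = Mesh(devices, ("x",))
    sharding = NamedSharding(mesh, PartitionSpec("x"))  # split axis 0 over mesh axis "x"
    x = jax.device_put(jnp.ones((n_devices, 10)), sharding)

    f = jax.jit(lambda a: jnp.sin(a) * 2.0)
    f(x).block_until_ready()  # warm-up: compile outside the timed loop

    samples = []
    for _ in range(reps):
        start = time.perf_counter()
        f(x).block_until_ready()
        samples.append(time.perf_counter() - start)
    return float(np.median(samples))


results = {n: time_sharded(n) for n in (4, 8)}
for n, t in results.items():
    print(f"{n} devices: {t * 1e6:.1f} us per call")
```

Using the median of many steady-state calls, with compilation excluded, keeps the 4-device and 8-device numbers comparable.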

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (6 total, 6 local): [cuda(id=0) cuda(id=1) ... cuda(id=4) cuda(id=5)]
process_count: 1

release='5.15.0-89-generic', version='#99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023', machine='x86_64')

Posted by lockwo, May 24 '24 01:05