
Read speeds decrease 2x when reading with fewer processes

Open heiner opened this issue 5 months ago • 10 comments

The issue

Given a specific checkpoint, load it in two different settings:

  1. Load it with 64 nodes, 512 GPUs, 512 processes (1 GPU / process).
  2. Load it with 64 nodes, 512 GPUs, 64 processes (8 GPUs / process).

What I observe:

  1. Using 512 processes, reading takes ~20 seconds.
  2. Using 64 processes, reading takes ~40 seconds (2x).

The checkpoint in question is also written with 512 processes (see below for repro). Except for the number of processes, nothing else changes (sharding etc. stays the same).
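To put the two timings on the same footing, here is a rough back-of-the-envelope sketch. It assumes one equal-sized shard per GPU (an assumption for illustration, not something measured here) and uses the approximate ~20 s and ~40 s figures from above. The function name `load_rates` is hypothetical.

```python
def load_rates(total_gpus: int, num_processes: int, wall_seconds: float) -> dict:
    """Express a load time as shard-read rates.

    Assumption: every GPU holds one shard and all shards have the same size.
    """
    shards_per_process = total_gpus / num_processes
    return {
        "shards_per_process": shards_per_process,
        # Rate for the whole job.
        "aggregate_shards_per_sec": total_gpus / wall_seconds,
        # Rate seen by a single process.
        "per_process_shards_per_sec": shards_per_process / wall_seconds,
    }


# Approximate timings reported above.
many = load_rates(total_gpus=512, num_processes=512, wall_seconds=20.0)
few = load_rates(total_gpus=512, num_processes=64, wall_seconds=40.0)

for name, rates in (("512 processes", many), ("64 processes", few)):
    print(name, rates)
```

Under that assumption, each process in the 64-process run reads 8x as many shards in 2x the time, so per-process throughput is about 4x higher while aggregate throughput is about 2x lower.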

To reproduce

Download this file and run it in a context with 64 nodes, 8 GPUs each. Make sure hostfile has the hostnames of the 64 nodes. (mpirun isn't essential here; it's just a way to spawn these processes.)

To create the checkpoint:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

To load the checkpoint with 512 processes:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

This takes ~20 sec for me.

To load the checkpoint with 64 processes:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 64 -npernode 1 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

This takes ~40 sec for me.

The issue doesn't seem to be in Orbax, because the same slowdown happens with a plain jax.experimental.array_serialization.serialization.async_deserialize.
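For anyone who wants to time the load path in isolation, here is a minimal timing harness for a single process that issues several reads concurrently, as the 8-GPUs-per-process setup does. It is only a sketch: `fake_read` is a hypothetical stand-in that sleeps instead of touching storage, and a real measurement would await the actual deserialization call in its place.

```python
import asyncio
import time


async def fake_read(shard_id: int, seconds: float) -> int:
    # Hypothetical stand-in for one shard read; it only sleeps.
    # A real run would await the deserialization call here instead.
    await asyncio.sleep(seconds)
    return shard_id


async def timed_load(num_shards: int, seconds_per_shard: float) -> tuple[list[int], float]:
    """Issue all shard reads concurrently and return (results, wall-clock seconds)."""
    start = time.perf_counter()
    results = await asyncio.gather(
        *(fake_read(i, seconds_per_shard) for i in range(num_shards))
    )
    return list(results), time.perf_counter() - start


if __name__ == "__main__":
    # One process loading 8 shards at once, mirroring the 8 GPUs / process case.
    shards, elapsed = asyncio.run(timed_load(8, 0.05))
    print(f"loaded {len(shards)} shards in {elapsed:.3f} s")
```

With the sleeping stand-in, the 8 reads overlap and the total stays close to the time of a single read. If the real reads take noticeably longer than one read in isolation, that points to something serializing them inside the process.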

heiner, Sep 06 '24 00:09