jax

Slow host to GPU transfer with `device_put`

Open ZacCranko opened this issue 11 months ago • 6 comments

Description

I've observed very slow host-to-GPU transfer speeds. I'm using the following benchmark script, which you may enjoy at your leisure.

import os
import argparse
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_STDERR_MIN_LOG_LEVEL"] = "0"
os.environ["TF_XLA_FLAGS"] = "--vmodule=gpu_transfer_manager=3"

import jax

import numpy as np

# A100 and tpu v4 are on 64GB/s PCI busses (16GB/s/chip), H100 and TPU-v5  are on 128GB/s (I think?) (32GB/s/chip for the TPU),
PCI_BUS_SPEED = 64 * 10**9
NUM_GIGABYTES = 10
DTYPE = np.float32

jax.device_put(jax.numpy.array(0)).delete()  # warmup pjit


def benchmark_transfer(pci_bandwidth=PCI_BUS_SPEED, num_gigabytes=NUM_GIGABYTES, dtype=DTYPE):
    denom = np.dtype(dtype).itemsize  # bytes per element; np.finfo only accepts floating dtypes
    candidate_array = np.arange(num_gigabytes * 1000 * 10**6 // denom, dtype=dtype)
    array_size = candidate_array.itemsize * candidate_array.size
    print(f"array created with size {array_size / 10**9:.2f}GB")

    for _ in range(5):
        t = time.time()
        on_device = jax.device_put(candidate_array)
        on_device.block_until_ready()
        duration = time.time() - t

        transfer_speed = array_size / duration
        print(
            f"transfer time {duration:.2f}s, transfer rate {transfer_speed / 10**9:.2f}GB/s, pci bus utilization {transfer_speed/pci_bandwidth:.1%}"
        )
        on_device.delete()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pci_bandwidth", type=float, default=PCI_BUS_SPEED, help="PCI bus speed in bytes per second")
    parser.add_argument("--num_gigabytes", type=int, default=NUM_GIGABYTES, help="Number of gigabytes to transfer")
    parser.add_argument("--dtype", type=str, default=DTYPE, help="Data type of the array")
    args = parser.parse_args()

    benchmark_transfer(args.pci_bandwidth, args.num_gigabytes, args.dtype)

On my A100 system it generates the following output.

2024-03-13 05:39:41.827859: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 05:39:41.939174: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 18.10s, transfer rate 0.55GB/s, pci bus utilization 0.9%

Please let me know if I'm doing something wrong (likely) or if this is to be expected! Many thanks for your assistance and time. Zac

Update: modified the script to rerun the transfer several times.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='zacgpu', release='6.5.0-1014-gcp', version='#14~22.04.1-Ubuntu SMP Sat Feb 10 04:57:00 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Mar 12 22:22:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          Off | 00000000:00:05.0 Off |                    0 |
| N/A   35C    P0              61W / 400W |    429MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    562519      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+

ZacCranko · Mar 12 '24 22:03

import os
import argparse
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_STDERR_MIN_LOG_LEVEL"] = "0"
os.environ["TF_XLA_FLAGS"] = "--vmodule=gpu_transfer_manager=3"

import jax

import numpy as np

# A100 and tpu v4 are on 64GB/s PCI busses (16GB/s/chip), H100 and TPU-v5  are on 128GB/s (I think?) (32GB/s/chip for the TPU),
PCI_BUS_SPEED = 32 * 10**9
NUM_GIGABYTES = 10
DTYPE = np.float32

jax.device_put(jax.numpy.array(0)).delete()  # warmup pjit


def benchmark_transfer(pci_bandwidth=PCI_BUS_SPEED, num_gigabytes=NUM_GIGABYTES, dtype=DTYPE):
    denom = np.dtype(dtype).itemsize  # bytes per element; np.finfo only accepts floating dtypes
    candidate_array = np.arange(num_gigabytes * 1000 * 10**6 // denom, dtype=dtype)
    array_size = candidate_array.itemsize * candidate_array.size
    print(f"array created with size {array_size / 10**9:.2f}GB")
    
    def run_once():
        start = time.time()
        on_device = jax.device_put(candidate_array)
        on_device.block_until_ready()
        duration = time.time() - start

        transfer_speed = array_size / duration
        print(
            f"transfer time {duration:.2f}s, transfer rate {transfer_speed / 10**9:.2f}GB/s, pci bus utilization {transfer_speed/pci_bandwidth:.1%}"
        )
        on_device.delete()

    # The first call includes one-time warm-up cost; compare the later calls.
    for _ in range(3):
        run_once()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pci_bandwidth", type=float, default=PCI_BUS_SPEED, help="PCI bus speed in bytes per second")
    parser.add_argument("--num_gigabytes", type=int, default=NUM_GIGABYTES, help="Number of gigabytes to transfer")
    parser.add_argument("--dtype", type=str, default=DTYPE, help="Data type of the array")
    args = parser.parse_args()

    benchmark_transfer(args.pci_bandwidth, args.num_gigabytes, args.dtype)

Can you try this script? The first call is slower than the others because JAX is compiling, so you must discard it. Also, on A100 you should expect only half the bandwidth you assumed, because your number is the sum of both directions. Here is the output I get on V100; the PCI utilization ratio isn't accurate there, because the bus speed in the script should be lower for V100.

2024-03-13 16:23:50.370963: I external/xla/xla/service/dump.cc:507] HloModule dump enabled with path prefix: , suffix: before_optimizations
2024-03-13 16:23:54.283633: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 11.06s, transfer rate 0.90GB/s, pci bus utilization 2.8%
transfer time 2.75s, transfer rate 3.64GB/s, pci bus utilization 11.4%
transfer time 2.76s, transfer rate 3.63GB/s, pci bus utilization 11.3%

That would still be only 22% efficient for V100.
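The utilization arithmetic behind these numbers can be sketched as below. A host-to-device copy uses only one direction of the link, so the denominator should be the one-way bandwidth, which is half of a bidirectional figure such as the 64 GB/s in the original script. The one-way values here are rounded nominal assumptions (PCIe 3.0 x16 for V100, PCIe 4.0 x16 for A100), and the helper names are hypothetical.

```python
# Rounded nominal one-way (per-direction) link bandwidths, in bytes/s.
# Assumptions: ~16 GB/s for PCIe 3.0 x16 (V100), ~32 GB/s for PCIe 4.0 x16 (A100).
V100_ONE_WAY_BW = 16 * 10**9
A100_ONE_WAY_BW = 32 * 10**9


def one_way(bidirectional_bw: float) -> float:
    """A one-directional copy can use at most half of a bidirectional figure."""
    return bidirectional_bw / 2


def utilization(measured_rate: float, one_way_bw: float) -> float:
    """Fraction of the one-way link bandwidth achieved by a measured copy."""
    return measured_rate / one_way_bw


# nouiz's steady-state V100 rate from the log above: 3.64 GB/s.
v100 = utilization(3.64e9, V100_ONE_WAY_BW)

# Dividing by 64 GB/s (both directions) understates A100 utilization by 2x.
a100_bw = one_way(64 * 10**9)
a100 = utilization(2.02e9, a100_bw)  # steady-state jax 0.4.25 A100 rate from this thread

print(f"V100: {v100:.1%}, A100: {a100:.1%}")
```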

When I run it under nsys, I see the transfer is split into many parts. I'm not sure whether CUDA or XLA triggers this behavior. Are you only looking to understand this, or do you need to speed it up?
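The advice above to discard the first call can be turned into a small reusable harness. A minimal sketch: `time_transfers` and `median_rate` are hypothetical helpers, not JAX APIs. They accept any zero-argument callable, so the caller supplies the transfer itself, and `time.perf_counter` is used because it is a monotonic, high-resolution clock suited to benchmarking.

```python
import statistics
import time
from typing import Callable, List, Tuple


def time_transfers(transfer: Callable[[], object], n_runs: int = 5) -> Tuple[float, List[float]]:
    """Call `transfer` n_runs times and time each call.

    Returns (first_run_seconds, later_run_seconds). The first run is returned
    separately because it includes one-time warm-up costs and should not be
    averaged with the others.
    """
    if n_runs < 2:
        raise ValueError("need one warm-up run plus at least one measured run")
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        transfer()  # must block until the copy has actually finished
        durations.append(time.perf_counter() - start)
    return durations[0], durations[1:]


def median_rate(num_bytes: int, later_runs: List[float]) -> float:
    """Steady-state throughput in bytes/s, from the median post-warm-up time."""
    return num_bytes / statistics.median(later_runs)
```

With JAX this could be called as `time_transfers(lambda: jax.device_put(candidate_array).block_until_ready())`, reusing `candidate_array` from the scripts above. Taking the median of the later runs makes the result less sensitive to a single slow run.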

nouiz · Mar 13 '24 17:03

I think this may have regressed in 0.4.25. Try 0.4.24? (Both jax and jaxlib.)

hawkinsp · Mar 13 '24 18:03

@nouiz When I run this repeatedly there is a slight speedup

2024-03-13 18:31:36.445516: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 18:31:36.555472: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 18.06s, transfer rate 0.55GB/s, pci bus utilization 0.9%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.97s, transfer rate 2.01GB/s, pci bus utilization 3.1%
transfer time 4.94s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%

3% is still unacceptable. Also, during development I often need to load a checkpoint at startup, so if that is slow it slows everything down. Even if the first-call cost is due to compilation, JAX should not need 15 seconds to compile a simple host-to-device transfer.
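A rough way to put a number on the first-call cost is to subtract a typical later transfer time from the first one. A sketch using the jax 0.4.25 timings above; the difference lumps together everything that happens only on the first call, so it bounds the compilation and warm-up cost rather than isolating compilation. `one_time_overhead` is a hypothetical helper name.

```python
import statistics
from typing import List

# Timings in seconds, copied from the jax 0.4.25 A100 run above.
FIRST_RUN = 18.06
LATER_RUNS = [4.96, 4.97, 4.94, 4.96]


def one_time_overhead(first: float, later: List[float]) -> float:
    """Extra seconds spent on the first call compared with a typical later call."""
    return first - statistics.mean(later)


overhead = one_time_overhead(FIRST_RUN, LATER_RUNS)
print(f"estimated one-time overhead: {overhead:.1f}s")  # prints: estimated one-time overhead: 13.1s
```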

@hawkinsp Under jax 0.4.24 it's a hair better, but nothing to write home about

2024-03-13 18:37:21.249077: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 18:37:21.433903: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:469] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 16.09s, transfer rate 0.62GB/s, pci bus utilization 1.0%
transfer time 2.90s, transfer rate 3.44GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
transfer time 2.92s, transfer rate 3.43GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
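To compare the two jax versions without reading numbers off by eye, the benchmark's output lines can be parsed and the warm-up run dropped. A small sketch; the regular expression assumes exactly the `transfer time ...s, transfer rate ...GB/s` format printed by the scripts in this thread, and `mean_steady_rate_gbps` is a hypothetical helper name.

```python
import re
import statistics

# Matches the lines printed by the benchmark scripts in this thread.
LINE_RE = re.compile(r"transfer time ([\d.]+)s, transfer rate ([\d.]+)GB/s")


def mean_steady_rate_gbps(log: str) -> float:
    """Mean transfer rate in GB/s, skipping the first (warm-up) run."""
    rates = [float(m.group(2)) for m in LINE_RE.finditer(log)]
    if len(rates) < 2:
        raise ValueError("need a warm-up run plus at least one measured run")
    return statistics.mean(rates[1:])


# Output copied from the A100 runs above.
LOG_0_4_25 = """\
transfer time 18.06s, transfer rate 0.55GB/s, pci bus utilization 0.9%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.97s, transfer rate 2.01GB/s, pci bus utilization 3.1%
transfer time 4.94s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%
"""

LOG_0_4_24 = """\
transfer time 16.09s, transfer rate 0.62GB/s, pci bus utilization 1.0%
transfer time 2.90s, transfer rate 3.44GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
transfer time 2.92s, transfer rate 3.43GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
"""

old = mean_steady_rate_gbps(LOG_0_4_24)
new = mean_steady_rate_gbps(LOG_0_4_25)
print(f"0.4.24: {old:.2f} GB/s, 0.4.25: {new:.2f} GB/s, ratio {old / new:.2f}x")
```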

ZacCranko · Mar 13 '24 18:03

https://github.com/openxla/xla/pull/10528 will revert to the 0.4.24 speeds, which is something.

hawkinsp · Mar 13 '24 18:03

@ZacCranko Did you fix the expected speed as discussed at https://github.com/google/jax/issues/20209#issuecomment-1995011832 ? I can take another look when the PR is merged.

nouiz · Mar 13 '24 19:03

https://github.com/openxla/xla/pull/10629 should fix this issue, although I think we can do better still (next week!)

hawkinsp · Mar 15 '24 23:03