Slow host to GPU transfer with `device_put`
Description
I've observed very slow host-to-GPU transfer speeds. I'm using the following benchmark script, which you may enjoy at your leisure.
import os
import argparse
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_STDERR_MIN_LOG_LEVEL"] = "0"
os.environ["TF_XLA_FLAGS"] = "--vmodule=gpu_transfer_manager=3"

import jax
import numpy as np

# A100 and TPU v4 are on 64GB/s PCI buses (16GB/s/chip); H100 and TPU v5 are on 128GB/s (I think?) (32GB/s/chip for the TPU).
PCI_BUS_SPEED = 64 * 10**9
NUM_GIGABYTES = 10
DTYPE = np.float32

jax.device_put(jax.numpy.array(0)).delete()  # warmup pjit


def benchmark_transfer(pci_bandwidth=PCI_BUS_SPEED, num_gigabytes=NUM_GIGABYTES, dtype=DTYPE):
    denom = np.finfo(dtype).bits // 8
    candidate_array = np.arange(num_gigabytes * 1000 * 10**6 // denom, dtype=dtype)
    array_size = candidate_array.itemsize * candidate_array.size
    print(f"array created with size {array_size / 10**9:.2f}GB")
    for _ in range(5):
        t = time.time()
        on_device = jax.device_put(candidate_array)
        on_device.block_until_ready()
        duration = time.time() - t
        transfer_speed = array_size / duration
        print(
            f"transfer time {duration:.2f}s, transfer rate {transfer_speed / 10**9:.2f}GB/s, pci bus utilization {transfer_speed / pci_bandwidth:.1%}"
        )
        on_device.delete()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pci_bandwidth", type=float, default=PCI_BUS_SPEED, help="PCI bus speed in bytes per second")
    parser.add_argument("--num_gigabytes", type=int, default=NUM_GIGABYTES, help="Number of gigabytes to transfer")
    parser.add_argument("--dtype", type=str, default=DTYPE, help="Data type of the array")
    args = parser.parse_args()
    benchmark_transfer(args.pci_bandwidth, args.num_gigabytes, args.dtype)
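For reference, a typical invocation might look like this (the filename benchmark_transfer.py is just a placeholder for wherever you save the script; the flags default to the constants at the top):
$ python benchmark_transfer.py --num_gigabytes 10 --pci_bandwidth 64e9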
On my A100 system it generates the following output.
2024-03-13 05:39:41.827859: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 05:39:41.939174: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 18.10s, transfer rate 0.55GB/s, pci bus utilization 0.9%
Please let me know if I'm doing something wrong (likely) or if this is to be expected! Many thanks for your assistance and time. Zac
Update: modified the script to rerun the transfer several times.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='zacgpu', release='6.5.0-1014-gcp', version='#14~22.04.1-Ubuntu SMP Sat Feb 10 04:57:00 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Mar 12 22:22:19 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-80GB Off | 00000000:00:05.0 Off | 0 |
| N/A 35C P0 61W / 400W | 429MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 562519 C python 416MiB |
+---------------------------------------------------------------------------------------+
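As an aside, the version/device block above matches what JAX's built-in helper prints; rerunning it is the easiest way to collect this info for a report (on GPU machines it also includes nvidia-smi output, if I remember correctly):
import jax
# Prints jax/jaxlib/numpy versions, Python version, device list, and platform details.
jax.print_environment_info()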
import os
import argparse
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_STDERR_MIN_LOG_LEVEL"] = "0"
os.environ["TF_XLA_FLAGS"] = "--vmodule=gpu_transfer_manager=3"

import jax
import numpy as np

# A100 and TPU v4 are on 64GB/s PCI buses (16GB/s/chip); H100 and TPU v5 are on 128GB/s (I think?) (32GB/s/chip for the TPU).
PCI_BUS_SPEED = 32 * 10**9
NUM_GIGABYTES = 10
DTYPE = np.float32

jax.device_put(jax.numpy.array(0)).delete()  # warmup pjit


def benchmark_transfer(pci_bandwidth=PCI_BUS_SPEED, num_gigabytes=NUM_GIGABYTES, dtype=DTYPE):
    denom = np.finfo(dtype).bits // 8
    candidate_array = np.arange(num_gigabytes * 1000 * 10**6 // denom, dtype=dtype)
    array_size = candidate_array.itemsize * candidate_array.size
    print(f"array created with size {array_size / 10**9:.2f}GB")

    def t():
        start = time.time()
        on_device = jax.device_put(candidate_array)
        on_device.block_until_ready()
        duration = time.time() - start
        transfer_speed = array_size / duration
        print(
            f"transfer time {duration:.2f}s, transfer rate {transfer_speed / 10**9:.2f}GB/s, pci bus utilization {transfer_speed / pci_bandwidth:.1%}"
        )
        on_device.delete()

    t()
    t()
    t()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pci_bandwidth", type=float, default=PCI_BUS_SPEED, help="PCI bus speed in bytes per second")
    parser.add_argument("--num_gigabytes", type=int, default=NUM_GIGABYTES, help="Number of gigabytes to transfer")
    parser.add_argument("--dtype", type=str, default=DTYPE, help="Data type of the array")
    args = parser.parse_args()
    benchmark_transfer(args.pci_bandwidth, args.num_gigabytes, args.dtype)
Can you try this script? The first call is slower than the others because JAX is compiling, so you should discard it. Also, on A100 you see half the performance you expected, because your 64 GB/s figure is the sum over both directions. Here is the output I get on a V100; the PCI utilization ratio there isn't the right one either, since the bus speed constant should be lower for V100.
2024-03-13 16:23:50.370963: I external/xla/xla/service/dump.cc:507] HloModule dump enabled with path prefix: , suffix: before_optimizations
2024-03-13 16:23:54.283633: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 11.06s, transfer rate 0.90GB/s, pci bus utilization 2.8%
transfer time 2.75s, transfer rate 3.64GB/s, pci bus utilization 11.4%
transfer time 2.76s, transfer rate 3.63GB/s, pci bus utilization 11.3%
That would still be only 22% efficient for V100.
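To make the bandwidth accounting explicit, here is a tiny sanity check using nominal per-direction PCIe figures; the 32 GB/s and 16 GB/s numbers are assumptions (roughly PCIe Gen4 x16 for A100 and Gen3 x16 for V100), not measurements:
# Assumed nominal unidirectional PCIe bandwidth, in bytes per second.
A100_PER_DIRECTION = 32e9  # ~PCIe Gen4 x16
V100_PER_DIRECTION = 16e9  # ~PCIe Gen3 x16
measured_v100 = 3.64e9     # best rate from the V100 run above
# Prints ~23%, roughly in line with the ~22% figure quoted above.
print(f"V100 utilization: {measured_v100 / V100_PER_DIRECTION:.0%}")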
When I run under nsys, I see the transfer is split into many parts. I'm not sure whether it's CUDA or XLA that triggers this behavior. Are you only looking to understand this, or do you need to speed it up?
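For anyone who wants to reproduce the nsys view, something along these lines should work (the script and report names are placeholders):
$ nsys profile --trace=cuda -o device_put_profile python benchmark_transfer.py
$ nsys stats device_put_profile.nsys-rep
The CUDA memcpy summary in the stats output is where the many small host-to-device copies should show up.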
I think this may have regressed in 0.4.25. Try 0.4.24? (Both jax and jaxlib.)
@nouiz When I run this repeatedly, there is a slight speedup:
2024-03-13 18:31:36.445516: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 18:31:36.555472: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:517] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 18.06s, transfer rate 0.55GB/s, pci bus utilization 0.9%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.97s, transfer rate 2.01GB/s, pci bus utilization 3.1%
transfer time 4.94s, transfer rate 2.02GB/s, pci bus utilization 3.2%
transfer time 4.96s, transfer rate 2.02GB/s, pci bus utilization 3.2%
3% is still unacceptable. However, during the development cycle I often need to load a checkpoint at the start, and if that first transfer is slow it slows everything down. Even if the slowness is due to compilation, JAX should not require 15 seconds to compile a simple host-to-device transfer.
@hawkinsp Under jax 0.4.24 it's a hair better, but nothing to write home about:
2024-03-13 18:37:21.249077: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
2024-03-13 18:37:21.433903: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:469] Loaded cuDNN version 8907
array created with size 10.00GB
transfer time 16.09s, transfer rate 0.62GB/s, pci bus utilization 1.0%
transfer time 2.90s, transfer rate 3.44GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
transfer time 2.92s, transfer rate 3.43GB/s, pci bus utilization 5.4%
transfer time 2.89s, transfer rate 3.46GB/s, pci bus utilization 5.4%
https://github.com/openxla/xla/pull/10528 will revert to the 0.4.24 speeds, which is something.
@ZacCranko Did you fix the expected speed as discussed at https://github.com/google/jax/issues/20209#issuecomment-1995011832 ? I can take another look when the PR is merged.
https://github.com/openxla/xla/pull/10629 should fix this issue, although I think we can do better still (next week!)