(Batched) Matrix Operations are much slower on GPU
Hi,
I am profiling batched matrix operations in JAX on GPU and CPU, and it is very surprising that the GPU is much slower than the CPU. For the code below, when B=1 the printed time is CPU = 0.00016665 vs GPU = 0.0067; when B=3 it is CPU = 0.00016355514 vs GPU = 0.01804113. So the GPU is not only slower than the CPU, it also does not seem to benefit from parallelism across the batch dimension.
For the platform, I am using 2 CPUs, a P100 GPU, jax==0.2.1, and jaxlib==0.1.55. Is this a known issue, and how can I fix it? I am particularly interested in how to exploit parallelism for batched operations.
import time

from jax import jit, random
import jax.numpy as jnp

B, N = 3, 2048

@jit
def run(iter):
    rng = random.PRNGKey(iter)
    a = random.normal(rng, shape=[B, N, N])
    a = a @ jnp.transpose(a, [0, 2, 1])  # make the matrices symmetric positive definite
    b = jnp.linalg.cholesky(a)
    return b

# warm up (trigger compilation before timing)
for _ in range(20):
    run(_)

start = time.time()
run(21)
print(time.time() - start)
Hi, I think I've also been hitting a similar problem of silent recompilations that eat up a lot of time. Try enabling JAX to log every time it triggers a recompilation and check that it only triggers once, on the first iteration, i.e.
JAX_LOG_COMPILES=1 python test.py
There was a bit in the documentation about lists not being hashable while tuples are, so the way you pass lists to random.normal for the shape and to jnp.transpose for the axis order may be causing those functions to be recompiled on every iteration. Changing these lists to tuples may fix it if the function is silently recompiling.
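For example (just a sketch of that change, applied to the snippet above), the two calls would become:

a = random.normal(rng, shape=(B, N, N))  # tuple instead of list for the shape
a = a @ jnp.transpose(a, (0, 2, 1))      # tuple instead of list for the axis order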
Hi, thanks for the suggestion. I tried adding os.environ["JAX_LOG_COMPILES"] = "1"; on both CPU and GPU the compilation message is printed only once. I also tried changing the lists to tuples, but it made no difference either.
There is something even stranger. I replaced run(21) with for _ in range(21, 100): run(_). When B=1, the printed time is CPU=21.4955 vs GPU=0.5453; when B=3, it is CPU=70.6205 vs GPU=1.4233.
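i.e. the timing block becomes roughly:

start = time.time()
for _ in range(21, 100):
    run(_)
print(time.time() - start)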
I wonder if the computation has actually finished by the time control returns to Python. Can you try changing the invocation to run(_).block_until_ready(), as suggested here: https://jax.readthedocs.io/en/latest/async_dispatch.html
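That is, something like the following, so the clock only stops once the device computation has actually finished:

start = time.time()
for _ in range(21, 100):
    run(_).block_until_ready()  # wait for the asynchronously dispatched computation to complete
print(time.time() - start)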
You are exactly right! After I use run(_).block_until_ready(), when B=1 the printed time is CPU=24.23 vs GPU=0.544; when B=3 it is CPU=80.69 vs GPU=1.43.
A final question: is it possible to parallelize these operations? Ideally I would expect the GPU to get faster through parallelism when B>1. However, for B=1, 3, 10 the times are 0.544, 1.43, 3.918, which does not look like much of a speedup.
Currently on GPU we do use batched algorithms for some matrix operations (e.g. Cholesky decomposition), but only for small matrices. Ultimately we're just calling cuSolver here at the moment.
On CPU more parallelism is certainly possible, although it seems your interest is on GPU.
In principle on GPU we could run multiple operations in parallel (using multiple GPU streams), although it's not something we support at the moment. It's also not obvious to me whether it would be profitable for a 2048x2048 matrix, although I don't know without trying it.
I expect support for multiple GPU streams is likely to be something we can experiment with more seriously in six or so months after some runtime and compiler changes have landed.
Thanks for the clarification!
Thank you again @hawkinsp for the explanations. I wonder if you have any suggestions for speeding up batched Cholesky decompositions (~1000x1000). A Cholesky decomposition takes quadratic memory, so I reckon GPU memory will be sufficient as long as the batch is not too large. Why do you think there might not be a benefit from parallelism?
To speed things up, pmap seems like one potential approach when multiple GPUs are available. If only one GPU is available to me, is jax.numpy.linalg.cholesky already the fastest option, or are there workarounds? Any suggestions would be very helpful, thanks.
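For reference, what I had in mind with pmap is roughly the following sketch (assuming multiple GPUs are visible and the batch size divides evenly across them):

import jax
import jax.numpy as jnp
from jax import pmap, random

n_dev = jax.device_count()
B, N = 8, 1024
assert B % n_dev == 0  # shard the batch evenly across devices

rng = random.PRNGKey(0)
A = random.normal(rng, shape=(B, N, N))
A = A @ jnp.transpose(A, (0, 2, 1))  # make the matrices symmetric positive definite

# pmap maps over the leading (device) axis; on each device jnp.linalg.cholesky
# still batches over the remaining leading dimension.
A_sharded = A.reshape(n_dev, B // n_dev, N, N)
L = pmap(jnp.linalg.cholesky)(A_sharded).block_until_ready()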
We do not know of any workaround for a single GPU. Any objection to closing this issue?
For a bit of context, below is what we obtained with pytest-benchmark using an A100. cuSolver's implementation of Cholesky decomposition already parallelizes w.r.t. batch size.
A modified version of OP's snippet:
import pytest

from jax import jit, random
import jax.numpy as jnp

@jit
def chol(A):
    return jnp.linalg.cholesky(A)

@pytest.mark.parametrize('b', [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
@pytest.mark.parametrize('N', [2048])
def test_chol_batch_size(benchmark, b, N):
    rng = random.PRNGKey(0)
    A = random.normal(rng, shape=[b, N, N])
    A = A @ jnp.transpose(A, [0, 2, 1])  # symmetric positive definite input
    # warm up (compile and run a few times before timing)
    chol(A).block_until_ready()
    chol(A).block_until_ready()
    chol(A).block_until_ready()
    benchmark(lambda: chol(A).block_until_ready())

@pytest.mark.parametrize('b', [1])
@pytest.mark.parametrize('N', [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
def test_chol_matrix_size(benchmark, b, N):
    rng = random.PRNGKey(0)
    A = random.normal(rng, shape=[b, N, N])
    A = A @ jnp.transpose(A, [0, 2, 1])  # symmetric positive definite input
    # warm up (compile and run a few times before timing)
    chol(A).block_until_ready()
    chol(A).block_until_ready()
    chol(A).block_until_ready()
    benchmark(lambda: chol(A).block_until_ready())
Result:
------------------------------------------- benchmark 'test_chol_batch_size': 10 tests ------------------------------------------
Name (time in ms) Median Mean StdDev IQR Iterations Rounds
---------------------------------------------------------------------------------------------------------------------------------
test_chol_batch_size[2048-1] 3.2143 (1.0) 3.8463 (1.0) 1.0606 (100.10) 2.4045 (224.23) 1 178
test_chol_batch_size[2048-2] 5.0583 (1.57) 5.0516 (1.31) 0.0344 (3.24) 0.0428 (3.99) 1 200
test_chol_batch_size[2048-4] 5.9490 (1.85) 5.9524 (1.55) 0.0141 (1.33) 0.0182 (1.69) 1 170
test_chol_batch_size[2048-8] 7.6304 (2.37) 7.6320 (1.98) 0.0106 (1.0) 0.0107 (1.0) 1 132
test_chol_batch_size[2048-16] 10.7962 (3.36) 10.8420 (2.82) 0.2949 (27.83) 0.1056 (9.84) 1 94
test_chol_batch_size[2048-32] 17.9783 (5.59) 18.0078 (4.68) 0.2234 (21.08) 0.3887 (36.25) 1 57
test_chol_batch_size[2048-64] 33.2420 (10.34) 33.2182 (8.64) 0.3663 (34.57) 0.4239 (39.53) 1 31
test_chol_batch_size[2048-128] 64.9609 (20.21) 65.0094 (16.90) 0.2303 (21.74) 0.2981 (27.80) 1 16
test_chol_batch_size[2048-256] 128.6795 (40.03) 128.7425 (33.47) 0.3383 (31.93) 0.4527 (42.22) 1 8
test_chol_batch_size[2048-512] 262.0506 (81.53) 261.9398 (68.10) 2.7970 (263.98) 3.1338 (292.24) 1 5
---------------------------------------------------------------------------------------------------------------------------------
------------------------------------------- benchmark 'test_chol_matrix_size': 6 tests ------------------------------------------
Name (time in ms) Median Mean StdDev IQR Iterations Rounds
---------------------------------------------------------------------------------------------------------------------------------
test_chol_matrix_size[1024-1] 1.4406 (1.0) 1.7101 (1.0) 0.4613 (18.81) 0.9990 (196.38) 1 361
test_chol_matrix_size[2048-1] 3.2106 (2.23) 3.2121 (1.88) 0.0245 (1.0) 0.0051 (1.0) 1 312
test_chol_matrix_size[4096-1] 7.7203 (5.36) 7.7340 (4.52) 0.0312 (1.27) 0.0433 (8.51) 1 130
test_chol_matrix_size[8192-1] 24.5720 (17.06) 24.6321 (14.40) 0.2355 (9.61) 0.0491 (9.65) 1 41
test_chol_matrix_size[16384-1] 125.7637 (87.30) 125.7540 (73.54) 0.5370 (21.90) 0.8374 (164.63) 1 8
test_chol_matrix_size[32768-1] 830.0251 (576.15) 830.3184 (485.54) 1.0143 (41.36) 1.7149 (337.12) 1 5
---------------------------------------------------------------------------------------------------------------------------------