
(Batched) Matrix Operations are much slower on GPU

ssydasheng opened this issue 4 years ago (7 comments)

Hi,

I am profiling JAX's batched matrix operations on GPU and CPU, and it is very surprising that the GPU is much slower than the CPU. For the code below, with B=1 the printed time is CPU=0.00016665 vs GPU=0.0067 (seconds); with B=3 it is CPU=0.00016355514 vs GPU=0.01804113. So the GPU is not only slower than the CPU, it also does not seem to benefit from parallelism when operations are batched.

For the platform, I am using 2 CPUs, a P100 GPU, jax=0.2.1 and jaxlib=0.1.55. I wonder whether this is a known issue, and how I can fix it. I am particularly interested in how to exploit parallelism for batched operations.

import time
from jax import jit, random
import jax.numpy as jnp

B, N = 3, 2048  # batch size and matrix dimension

@jit
def run(seed):
    rng = random.PRNGKey(seed)
    a = random.normal(rng, shape=[B, N, N])
    a = a @ jnp.transpose(a, [0, 2, 1])  # batch of symmetric positive (semi)definite matrices
    b = jnp.linalg.cholesky(a)
    return b


# warm-up: the first call compiles the function
for i in range(20):
    run(i)

start = time.time()
run(21)
print(time.time() - start)

ssydasheng, Nov 27 '20

Hi, I think I've also been hitting a similar problem: silent recompilations that eat up a lot of time. Try enabling JAX's logging of every compilation and check that it compiles only once, on the first iteration, e.g.:

JAX_LOG_COMPILES=1 python test.py
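The same check can be switched on from inside a script. A minimal sketch, assuming a reasonably recent JAX where the `jax_log_compiles` option can be set with `jax.config.update`; `square` is a hypothetical example function. Each compilation is then reported through Python's `logging` module. If you set `os.environ["JAX_LOG_COMPILES"] = "1"` from Python instead, it generally needs to happen before `import jax`, because the environment variable is read when JAX's configuration is initialized.

```python
import jax
import jax.numpy as jnp

# Programmatic alternative to running the script with JAX_LOG_COMPILES=1.
jax.config.update("jax_log_compiles", True)


@jax.jit
def square(x):  # hypothetical example function
    return x * x


y1 = square(jnp.ones(3))  # first call with this shape/dtype: compiles (logged)
y2 = square(jnp.ones(3))  # same shape/dtype: served from the cache, nothing logged
y3 = square(jnp.ones(4))  # new shape: compiles again (logged)
```

If the log shows a compilation on every iteration of a loop, something in the inputs (shape, dtype, or a static argument) is changing between calls.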

I read something in the documentation about "lists not being hashable, while tuples are", so it's possible that passing lists to random.normal (for the shape) and to jnp.transpose (for the axis order) causes those functions to be recompiled on every iteration. Changing them to tuples, using the (,) syntax, may fix this if it is silently recompiling.
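Hashability matters for arguments that are marked static with `static_argnums`: JAX uses their values as part of the compilation cache key, so they must be hashable (passing a list there raises an error), and each new value triggers a fresh trace. A minimal sketch, assuming a recent JAX; `zeros_of_shape` and `trace_count` are hypothetical names used only to make retracing visible, since Python side effects inside a jitted function run only while JAX traces it.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only


@partial(jax.jit, static_argnums=0)
def zeros_of_shape(shape):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX traces the function
    return jnp.zeros(shape)


zeros_of_shape((2, 3))  # first call: traces and compiles
zeros_of_shape((2, 3))  # equal, hashable static argument: cache hit, no retrace
zeros_of_shape((4, 5))  # different static value: traces and compiles again
print(trace_count)  # 2
```

In the original snippet, however, the shape and axis lists are created inside the jitted function rather than passed in as arguments, so they are trace-time constants and never become part of the cache key.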

JossWhittle, Nov 27 '20

Hi, thanks for the suggestion. I tried adding os.environ["JAX_LOG_COMPILES"] = "1"; on both CPU and GPU the compilation is logged only once. I also tried changing the lists to tuples, but that made no difference either.

Something stranger happens, though. When I replace run(21) with for _ in range(21, 100): run(_), the printed time for B=1 is CPU=21.4955 vs GPU=0.5453, and for B=3 it is CPU=70.6205 vs GPU=1.4233.

ssydasheng, Nov 27 '20

I wonder whether the computation has actually finished by the time control returns to Python. Can you try changing the invocation to run(_).block_until_ready(), as suggested here: https://jax.readthedocs.io/en/latest/async_dispatch.html
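A minimal sketch of the corrected timing, assuming a recent JAX and using a smaller N than the original (an assumption so it runs quickly on a CPU): with asynchronous dispatch, the jitted call may return before the computation has finished, and `block_until_ready()` waits for the result. `time.perf_counter` is used here instead of `time.time` because it is intended for measuring intervals.

```python
import time

import jax.numpy as jnp
from jax import jit, random

B, N = 3, 256  # smaller N than the original 2048 so the sketch runs quickly on CPU


@jit
def run(seed):
    rng = random.PRNGKey(seed)
    a = random.normal(rng, shape=(B, N, N))
    a = a @ jnp.transpose(a, (0, 2, 1))
    return jnp.linalg.cholesky(a)


run(0).block_until_ready()  # warm-up: compile and wait for the first result

start = time.perf_counter()
out = run(21)  # may return as soon as the work is dispatched (async dispatch)
dispatch_s = time.perf_counter() - start
out.block_until_ready()  # blocks until the device has finished computing `out`
total_s = time.perf_counter() - start

print(f"dispatch only: {dispatch_s:.6f} s, dispatch + compute: {total_s:.6f} s")
```

Only the second number reflects the actual cost of the Cholesky factorization; the first mostly measures Python and dispatch overhead.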

JossWhittle, Nov 27 '20

You are exactly right! After switching to run(_).block_until_ready(), the printed time for B=1 is CPU=24.23 vs GPU=0.544, and for B=3 it is CPU=80.69 vs GPU=1.43.

A final question: is it possible to parallelize these operations? Ideally I would expect the GPU to be faster thanks to parallelism when B>1. However, for B=1, 3, 10 the times are 0.544, 1.43 and 3.918, which I don't think is much of a speedup.

ssydasheng, Nov 27 '20

Currently, on GPU we do use batched algorithms for some matrix operations (e.g. Cholesky decomposition), but only for small matrices. Beyond that, we're just calling cuSOLVER at the moment.

On CPU more parallelism is certainly possible, although it seems your interest is in the GPU.
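From the user's side, `jnp.linalg.cholesky` already accepts a stack of matrices with leading batch dimensions, so there is no need to loop over the batch in Python, and `jax.vmap` over the single-matrix call produces the same result. A small sketch checking that the three forms agree, assuming a recent JAX; `make_spd_batch` is a hypothetical helper, and adding `n * I` is an assumption made only to keep the matrices well-conditioned in float32. Which kernel actually runs on the GPU for the batched call is decided inside JAX, as described in the comment above.

```python
import jax
import jax.numpy as jnp
from jax import random


def make_spd_batch(key, batch, n):
    """Hypothetical helper: a batch of well-conditioned SPD matrices."""
    a = random.normal(key, (batch, n, n))
    # A @ A^T is symmetric positive semidefinite; adding n * I makes it safely definite.
    return a @ jnp.swapaxes(a, -1, -2) + n * jnp.eye(n)


x = make_spd_batch(random.PRNGKey(0), batch=4, n=64)

batched = jnp.linalg.cholesky(x)                          # one call over the batch dimension
looped = jnp.stack([jnp.linalg.cholesky(m) for m in x])  # one call per matrix
vmapped = jax.vmap(jnp.linalg.cholesky)(x)                # vmap over the per-matrix function

print(bool(jnp.allclose(batched, looped, atol=1e-3)))
print(bool(jnp.allclose(batched, vmapped, atol=1e-3)))
```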

In principle on GPU we could run multiple operations in parallel (using multiple GPU streams), although it's not something we support at the moment. It's also not obvious to me whether it would be profitable for a 2048x2048 matrix, although I don't know without trying it.

I expect support for multiple GPU streams is likely to be something we can experiment with more seriously in six or so months after some runtime and compiler changes have landed.

hawkinsp, Nov 30 '20

Thanks for your clarification !

ssydasheng, Nov 30 '20

Thank you again @hawkinsp for the explanations. Do you have any suggestions for speeding up batched Cholesky decompositions of ~1000x1000 matrices? Cholesky decomposition needs only quadratic memory per matrix, so I reckon GPU memory will be sufficient as long as the batch is not too large. Why do you think there might be no benefit from parallelism?
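A back-of-the-envelope sketch of the scaling mentioned here, assuming float32 storage and the textbook estimate of roughly N^3/3 floating-point operations for an N x N Cholesky factorization: memory grows quadratically with N, but compute grows cubically. `cholesky_cost` is a hypothetical helper; for N=1000 it gives about 4 MB per matrix and about 3.3e8 flops per factorization.

```python
def cholesky_cost(n, batch=1, bytes_per_element=4):
    """Hypothetical helper: rough cost of dense Cholesky on a batch of n x n matrices.

    Memory: one n x n matrix per batch element (float32 by default).
    Compute: the textbook estimate of about n**3 / 3 flops per matrix.
    """
    memory_bytes = batch * n * n * bytes_per_element
    flops = batch * n**3 / 3
    return memory_bytes, flops


mem, flops = cholesky_cost(1000)
print(f"{mem / 1e6:.1f} MB, {flops:.2e} flops")
```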

To speed it up, pmap seems to be one potential approach when multiple GPUs are available. If I only have one GPU, is jax.numpy.linalg.cholesky already the fastest option, or are there workarounds? Any suggestions would be very helpful, thanks.
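A minimal sketch of the pmap idea, assuming a recent JAX and a batch size that is a multiple of the number of local devices: the batch is reshaped so that each device receives one slice and factorizes it with the already-batched `jnp.linalg.cholesky`. On a single device this degenerates to one slice, so it adds no parallelism there. Newer JAX versions also offer sharding-based alternatives to `pmap` (e.g. `jax.jit` with sharded inputs); this sketch sticks with the API mentioned in the question, and the `n * I` term is only there to keep the test matrices well-conditioned.

```python
import jax
import jax.numpy as jnp
from jax import random

n_dev = jax.local_device_count()  # e.g. the number of GPUs; 1 on a CPU-only machine
batch, n = 4 * n_dev, 128          # batch must be divisible by the device count

a = random.normal(random.PRNGKey(0), (batch, n, n))
x = a @ jnp.swapaxes(a, -1, -2) + n * jnp.eye(n)  # SPD batch (n * I added for conditioning)

# Split the batch across devices: shape (n_dev, batch // n_dev, n, n).
x_split = x.reshape(n_dev, batch // n_dev, n, n)

# Each device factorizes its own sub-batch; results are gathered back into one batch.
chol_per_device = jax.pmap(jnp.linalg.cholesky)
L = chol_per_device(x_split).reshape(batch, n, n)
```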

ssydasheng, Dec 07 '20

We don't know of any workaround for a single GPU. Any objection to closing this issue?

nouiz, Sep 26 '22

For a bit of context, below is what we obtained with pytest-benchmark using an A100. cuSolver's implementation of Cholesky decomposition already parallelizes w.r.t. batch size.

A modified version of OP's snippet:

import pytest
from jax import jit, random
import jax.numpy as jnp


@jit
def chol(A):
    return jnp.linalg.cholesky(A)


@pytest.mark.parametrize('b', [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
@pytest.mark.parametrize('N', [2048])
def test_chol_batch_size(benchmark, b, N):
    rng = random.PRNGKey(0)
    A = random.normal(rng, shape=(b, N, N))
    A = A @ jnp.transpose(A, (0, 2, 1))  # batch of symmetric positive (semi)definite matrices

    # warm up: compile, then run a few times, waiting for the GPU each time
    for _ in range(3):
        chol(A).block_until_ready()

    benchmark(lambda: chol(A).block_until_ready())


@pytest.mark.parametrize('b', [1])
@pytest.mark.parametrize('N', [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
def test_chol_matrix_size(benchmark, b, N):
    rng = random.PRNGKey(0)
    A = random.normal(rng, shape=(b, N, N))
    A = A @ jnp.transpose(A, (0, 2, 1))  # batch of symmetric positive (semi)definite matrices

    # warm up: compile, then run a few times, waiting for the GPU each time
    for _ in range(3):
        chol(A).block_until_ready()

    benchmark(lambda: chol(A).block_until_ready())

Result:

------------------------------------------- benchmark 'test_chol_batch_size': 10 tests ------------------------------------------
Name (time in ms)                    Median                Mean            StdDev               IQR            Iterations  Rounds
---------------------------------------------------------------------------------------------------------------------------------
test_chol_batch_size[2048-1]         3.2143 (1.0)        3.8463 (1.0)      1.0606 (100.10)   2.4045 (224.23)            1     178
test_chol_batch_size[2048-2]         5.0583 (1.57)       5.0516 (1.31)     0.0344 (3.24)     0.0428 (3.99)              1     200
test_chol_batch_size[2048-4]         5.9490 (1.85)       5.9524 (1.55)     0.0141 (1.33)     0.0182 (1.69)              1     170
test_chol_batch_size[2048-8]         7.6304 (2.37)       7.6320 (1.98)     0.0106 (1.0)      0.0107 (1.0)               1     132
test_chol_batch_size[2048-16]       10.7962 (3.36)      10.8420 (2.82)     0.2949 (27.83)    0.1056 (9.84)              1      94
test_chol_batch_size[2048-32]       17.9783 (5.59)      18.0078 (4.68)     0.2234 (21.08)    0.3887 (36.25)             1      57
test_chol_batch_size[2048-64]       33.2420 (10.34)     33.2182 (8.64)     0.3663 (34.57)    0.4239 (39.53)             1      31
test_chol_batch_size[2048-128]      64.9609 (20.21)     65.0094 (16.90)    0.2303 (21.74)    0.2981 (27.80)             1      16
test_chol_batch_size[2048-256]     128.6795 (40.03)    128.7425 (33.47)    0.3383 (31.93)    0.4527 (42.22)             1       8
test_chol_batch_size[2048-512]     262.0506 (81.53)    261.9398 (68.10)    2.7970 (263.98)   3.1338 (292.24)            1       5
---------------------------------------------------------------------------------------------------------------------------------

------------------------------------------- benchmark 'test_chol_matrix_size': 6 tests ------------------------------------------
Name (time in ms)                    Median                Mean            StdDev               IQR            Iterations  Rounds
---------------------------------------------------------------------------------------------------------------------------------
test_chol_matrix_size[1024-1]        1.4406 (1.0)        1.7101 (1.0)      0.4613 (18.81)    0.9990 (196.38)            1     361
test_chol_matrix_size[2048-1]        3.2106 (2.23)       3.2121 (1.88)     0.0245 (1.0)      0.0051 (1.0)               1     312
test_chol_matrix_size[4096-1]        7.7203 (5.36)       7.7340 (4.52)     0.0312 (1.27)     0.0433 (8.51)              1     130
test_chol_matrix_size[8192-1]       24.5720 (17.06)     24.6321 (14.40)    0.2355 (9.61)     0.0491 (9.65)              1      41
test_chol_matrix_size[16384-1]     125.7637 (87.30)    125.7540 (73.54)    0.5370 (21.90)    0.8374 (164.63)            1       8
test_chol_matrix_size[32768-1]     830.0251 (576.15)   830.3184 (485.54)   1.0143 (41.36)    1.7149 (337.12)            1       5
---------------------------------------------------------------------------------------------------------------------------------
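Dividing the median times of the 'test_chol_batch_size' table by the batch size makes the batching effect explicit. A small sketch using only the numbers reported above (N=2048, A100): the per-matrix time falls from about 3.2 ms at b=1 to roughly 0.5 ms from b=64 upward, i.e. batching improves throughput by roughly 6x before levelling off.

```python
# Median times in ms, copied from the 'test_chol_batch_size' table (N = 2048, A100).
median_ms = {
    1: 3.2143, 2: 5.0583, 4: 5.9490, 8: 7.6304, 16: 10.7962,
    32: 17.9783, 64: 33.2420, 128: 64.9609, 256: 128.6795, 512: 262.0506,
}

# Time per matrix for each batch size, and the batch size where it is lowest.
per_matrix_ms = {b: t / b for b, t in median_ms.items()}
best_b = min(per_matrix_ms, key=per_matrix_ms.get)

for b, t in per_matrix_ms.items():
    print(f"b={b:4d}: {t:.3f} ms per matrix")
speedup = per_matrix_ms[1] / per_matrix_ms[best_b]
print(f"best: b={best_b}, {speedup:.1f}x the per-matrix throughput of b=1")
```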

yhtang, Sep 27 '22