
simple pallas kernel hangs when input size exceeds some threshold

Open · zhixuan-lin opened this issue 1 year ago · 2 comments

Description

The following simple Pallas kernel, which just copies an array, hangs indefinitely:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(
    src,
    dst
):
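    # src and dst are Refs; under vmap, each kernel instance sees one row
    # of shape (seq_length,) and copies it element by element.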
    def body(i, carry):
        dst[i] = src[i]
        return carry

    _ = jax.lax.fori_loop(
        lower=0,
        upper=src.shape[0],
        body_fun=body,
        init_val=None
    )

@jax.jit
@jax.vmap
def copy_func(src):
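    # No grid or BlockSpec is specified, so the kernel sees the whole
    # per-example array at once; jax.vmap batches the pallas_call over
    # the leading axis.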

    func = pl.pallas_call(
        f=copy_kernel,
        out_shape=jax.ShapeDtypeStruct(src.shape, src.dtype)
    )

    dst = func(src)
    return dst


if __name__ == '__main__':
    batch_size = 2 ** 16
    seq_length = 2 ** 16
    dtype = jnp.float32
    # dtype = jnp.bfloat16
    src = jnp.zeros((batch_size, seq_length), dtype=dtype)
    print(f'Array elements: {src.size}')
    print(f'Array size: {src.nbytes / 1e9:.4f}GB')
    dst = copy_func(src)
    dst.block_until_ready()

Program output:

2024-05-17 10:18:51.191896: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Array elements: 4294967296
Array size: 17.1799GB

I remember reading somewhere that the above warning can be ignored, so I think it is unlikely to be related to the issue I'm seeing.

It looks like the program does not get stuck as long as batch_size * seq_length <= 2 ** 31. For example, changing either batch_size or seq_length from 2 ** 16 to 2 ** 15 makes it work fine, while changing dtype from float32 to bfloat16 does not fix the problem. Also, I'm using an A100 80GB, and with batch_size = seq_length = 2 ** 16 and dtype=float32 the array only takes roughly 17GB, so this probably has nothing to do with memory.
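
For what it's worth, the 2 ** 31 threshold lines up exactly with the signed 32-bit integer range: with 2 ** 31 elements the largest flat index is 2 ** 31 - 1, which is the int32 maximum. A quick sanity check of the arithmetic (my own illustration, assuming 32-bit indexing somewhere in the stack):

import numpy as np

int32_max = np.iinfo(np.int32).max   # 2147483647 == 2**31 - 1
works = 2 ** 15 * 2 ** 16            # 2**31 elements: runs fine
hangs = 2 ** 16 * 2 ** 16            # 2**32 elements: hangs
print(works - 1 <= int32_max)        # True: the largest flat index fits in int32
print(hangs - 1 <= int32_max)        # False: indices past int32_max would wrap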

Also, when it hangs, both GPU and CPU utilization are at zero.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cn-g020.server.mila.quebec', release='5.15.0-101-generic', version='#111-Ubuntu SMP Tue Mar 5 20:16:58 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri May 17 10:13:08 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:41:00.0 Off |                    0 |
| N/A   25C    P0              72W / 500W |    424MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   3144259      C   python                                      416MiB |
+---------------------------------------------------------------------------------------+

zhixuan-lin commented on May 17, 2024

I also tested jax==0.4.25 with CUDA 11, where I don't see the ptxas version warning, but the kernel still hangs indefinitely, so the hang likely has nothing to do with that warning.

zhixuan-lin commented on May 17, 2024

It looks like something overflows and the loop iterates forever, but I'm not sure where the overflow actually happens.
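
A sketch of how that could produce a hang (assuming the index is carried in int32 somewhere): incrementing past the int32 maximum silently wraps to the minimum int32 value, so a termination check like i < upper never becomes false:

import jax.numpy as jnp

i = jnp.array(2 ** 31 - 1, dtype=jnp.int32)  # int32 maximum
print(i + 1)  # -2147483648: the increment wraps to a negative value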

superbobry commented on May 23, 2024

Hi @zhixuan-lin,

Testing the provided repro on a GCP VM with an A100 40GB GPU and JAX 0.5.0 resulted in a RESOURCE_EXHAUSTED error.

Setting the XLA_PYTHON_CLIENT_PREALLOCATE flag to 'false' and the XLA_PYTHON_CLIENT_ALLOCATOR flag to 'platform' resolved the issue.
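
For reference, these can be set from Python as long as it happens before JAX initializes the GPU backend (or exported in the shell); a minimal sketch:

import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # don't grab ~75% of GPU memory up front
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'  # allocate and free on demand

import jax  # the flags are read when the backend is first initialized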

For batch_size=2**16 and seq_length=2**16, GPU memory usage appears to be around 33GB.

Thank you.

rajasekharporeddy commented on Feb 5, 2025