
Pallas: transpose matmul throws segmentation fault

Open cgel opened this issue 11 months ago • 7 comments

Description

Transposing a matrix and then computing a matrix product produces Segmentation fault (core dumped), while multiplying by the non-transposed matrix works fine. Here is a small reproducer:

from functools import partial
import jax
import jax.experimental.pallas as pl
import jax.numpy as jnp

def bug_kernel(O_ref, B, d):
    A = jnp.ones([B, B])
    x = A.T @ jnp.ones([B, d])  # fails with "Segmentation fault (core dumped)"
    # x = A @ jnp.ones([B, d])  # works
    pl.store(O_ref, jnp.arange(0, B), x)

def call_buggy_kernel():
    B, d = 32, 16
    grid = (1,)
    out = pl.pallas_call(
        partial(bug_kernel, B=B, d=d),
        grid=grid,
        out_shape=jax.ShapeDtypeStruct((B, d), jnp.float32),
    )()
call_buggy_kernel()
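As a possible interim workaround (a sketch, not verified against this particular Triton bug), the product with the transposed matrix can be rewritten via the identity A.T @ X == (X.T @ A).T, so that A itself is never transposed; whether the remaining transposes avoid the faulty layout path is an open question. The identity itself can be checked outside Pallas, here with plain NumPy:

```python
import numpy as np

B, d = 32, 16
A = np.arange(B * B, dtype=np.float32).reshape(B, B)
X = np.ones((B, d), dtype=np.float32)

# A.T @ X can be computed without explicitly transposing A:
lhs = A.T @ X
rhs = (X.T @ A).T  # only X and the result are transposed; A is used as-is
assert np.allclose(lhs, rhs)
```

Note this only moves the transposes onto different operands; it does not eliminate them, so it may or may not sidestep the miscompiled code path in the kernel.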

System info (python version, jaxlib version, accelerator, etc.)


jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.11.4 (main, Dec  7 2023, 15:43:41) [GCC 12.3.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='jacob-manifestai', release='6.2.0-39-generic', version='#40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023', machine='x86_64')


$ nvidia-smi
Sat Mar  9 09:44:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A6000               On  | 00000000:2D:00.0 Off |                  Off |
| 30%   26C    P2              25W / 300W |    269MiB / 49140MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:41:00.0 Off |                  Off |
| 30%   31C    P2              29W / 300W |    269MiB / 49140MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1820886      C   ...2-pallas-YfL-FD7B-py3.11/bin/python      262MiB |
|    1   N/A  N/A   1820886      C   ...2-pallas-YfL-FD7B-py3.11/bin/python      262MiB |
+---------------------------------------------------------------------------------------+

cgel avatar Mar 09 '24 14:03 cgel

Thanks, this looks like a Triton bug.

I am able to reproduce this on A100 internally. The error message coming from one of the Triton passes is getElemsPerThread is not supported for shared layout.

I will try to find a reproducer using Triton Python APIs.

superbobry avatar Mar 09 '24 20:03 superbobry

For what it's worth, I moved to Pallas to see if I could avoid another bug I had found in Triton, which also gave the error

getElemsPerThread is not supported for shared layout

cgel avatar Mar 09 '24 20:03 cgel

Hmm, I thought the other user mentioned that it's solved using the latest Triton.

Jokeren avatar Mar 10 '24 03:03 Jokeren

@cgel can you check if the issue is still present in the Triton nightly? If yes, the next jaxlib release should have the upstream fixes as well.

superbobry avatar Mar 10 '24 10:03 superbobry

That other issue is not solved by the Triton nightly (which, per their instructions, I installed with pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly). The nightly build is an older version of Triton (2.1.0) that does not include 3D matmul functionality at all.

It is fixed by building Triton from source, at the most recent commit to main. Will that make it into the next jaxlib release?

jbuckman avatar Mar 10 '24 13:03 jbuckman

@jbuckman likely yes, assuming that version of Triton is integrated into the openxla/triton repository by the next release.

hawkinsp avatar Mar 10 '24 15:03 hawkinsp

The nightly build has been broken for quite a while...sorry about that

Jokeren avatar Mar 10 '24 15:03 Jokeren

I confirmed this does not crash with the latest Triton nightly. Closing.

superbobry avatar Mar 20 '24 16:03 superbobry