Pallas: transpose matmul throws segmentation fault
Description
Transposing a matrix and then taking a matrix product produces "Segmentation fault (core dumped)", but multiplying by the non-transposed matrix works fine. Here is a small reproducer:
from functools import partial

import jax
import jax.experimental.pallas as pl
import jax.numpy as jnp


def bug_kernel(O_ref, B, d):
    A = jnp.ones([B, B])
    x = A.T @ jnp.ones([B, d])  # fails with "Segmentation fault (core dumped)"
    # x = A @ jnp.ones([B, d])  # works
    pl.store(O_ref, jnp.arange(0, B), x)


def call_buggy_kernel():
    B, d = 32, 16
    grid = (1,)
    out = pl.pallas_call(
        partial(bug_kernel, B=B, d=d),
        grid=grid,
        out_shape=jax.ShapeDtypeStruct((B, d), jnp.float32),
    )()


call_buggy_kernel()
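As a side check (a minimal sketch on top of the reproducer above, reusing bug_kernel and the imports from that snippet; the helper name is mine): pallas_call accepts interpret=True, which runs the kernel on the Pallas interpreter path instead of lowering through Triton, so it can help confirm whether the crash happens during Triton compilation rather than in the kernel logic itself.

def call_buggy_kernel_interpreted():
    # Sketch: same kernel, but interpret=True bypasses the Triton lowering.
    # If this runs, the segfault is in compilation/lowering, not kernel logic.
    B, d = 32, 16
    return pl.pallas_call(
        partial(bug_kernel, B=B, d=d),
        grid=(1,),
        out_shape=jax.ShapeDtypeStruct((B, d), jnp.float32),
        interpret=True,
    )()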
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.4 (main, Dec 7 2023, 15:43:41) [GCC 12.3.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='jacob-manifestai', release='6.2.0-39-generic', version='#40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023', machine='x86_64')
$ nvidia-smi
Sat Mar 9 09:44:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA RTX A6000 On | 00000000:2D:00.0 Off | Off |
| 30% 26C P2 25W / 300W | 269MiB / 49140MiB | 1% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA RTX A6000 On | 00000000:41:00.0 Off | Off |
| 30% 31C P2 29W / 300W | 269MiB / 49140MiB | 1% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1820886 C ...2-pallas-YfL-FD7B-py3.11/bin/python 262MiB |
| 1 N/A N/A 1820886 C ...2-pallas-YfL-FD7B-py3.11/bin/python 262MiB |
+---------------------------------------------------------------------------------------+
Thanks, this looks like a Triton bug.
I am able to reproduce this on A100 internally. The error message coming from one of the Triton passes is "getElemsPerThread is not supported for shared layout".
I will try to find a reproducer using Triton Python APIs.
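Something along these lines might be a starting point (a rough sketch only; the kernel name and the assumption that the Pallas lowering corresponds to tl.trans followed by tl.dot are mine, not a confirmed minimal case):

import torch
import triton
import triton.language as tl


@triton.jit
def transpose_dot_kernel(o_ptr, B: tl.constexpr, D: tl.constexpr):
    # Roughly mirrors the Pallas kernel: all-ones operands, transpose, then dot.
    a = tl.zeros([B, B], dtype=tl.float32) + 1.0
    x = tl.zeros([B, D], dtype=tl.float32) + 1.0
    out = tl.dot(tl.trans(a), x)  # the suspect A.T @ x pattern
    # Store the B x D result tile to the output buffer.
    offs = tl.arange(0, B)[:, None] * D + tl.arange(0, D)[None, :]
    tl.store(o_ptr + offs, out)


o = torch.empty((32, 16), dtype=torch.float32, device="cuda")
transpose_dot_kernel[(1,)](o, B=32, D=16)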
For what it's worth, I moved to Pallas to see if I could avoid another bug I had found in Triton, which also gave the error "getElemsPerThread is not supported for shared layout".
Hmm, I thought the other user mentioned that it's solved using the latest Triton.
@cgel can you check if the issue is still present in the Triton nightly? If yes, the next jaxlib release should have the upstream fixes as well.
That other issue is not solved by the Triton nightly (which, per their instructions, I installed with pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly), because the nightly build is an older version of Triton (2.1.0) that does not include 3D matmul functionality at all.
It is fixed by building Triton from source at the most recent commit to main. Will that make it into the next jaxlib release?
@jbuckman likely yes, assuming that version of Triton is integrated into the openxla/triton repository by the next release.
The nightly build has been broken for quite a while... sorry about that.
I confirmed this does not crash with the latest Triton nightly. Closing.