
Pallas strided loads ignore the stride argument on CUDA backend

Open · hirayaku opened this issue 10 months ago · 0 comments

Description

When I test strided loads in Pallas kernels on the CUDA backend, pallas.load appears to ignore the step of the slice argument. For example, the following code should print [0, 4, 8, 12], but it actually prints [0, 1, 2, 3].

```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def strided(x_ref, o_ref):
    # Load every 4th element of x_ref (start=0, stop=None, step=4).
    x = pl.load(x_ref, slice(0, None, 4))
    o_ref[:] = x

x = jnp.arange(16, dtype=jnp.uint32)
out = pl.pallas_call(
    strided, out_shape=jax.ShapeDtypeStruct((4,), x.dtype)
)(x)
print(out)  # expected [0, 4, 8, 12]; on CUDA this prints [0, 1, 2, 3]
```
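For comparison, here is a minimal pure-NumPy sketch of the result the report expects, assuming pl.load should follow ordinary Python/NumPy slice semantics (which jax.numpy mirrors for basic slicing). It also shows that the reported output is exactly what a contiguous slice of length 4 produces, i.e. what you get if the step is dropped. The helper `matches_strided_load` is hypothetical and is not part of Pallas; it only checks an output array against the reference.

```python
import numpy as np

# Same input as the repro: 16 consecutive uint32 values.
x = np.arange(16, dtype=np.uint32)

# Reference semantics of slice(0, None, 4): every 4th element starting at 0.
expected = x[slice(0, None, 4)]  # equivalent to x[::4]
print(expected)  # [ 0  4  8 12]

# The output reported on CUDA matches a contiguous slice of the same length,
# which is what a load that ignores the step would return.
step_ignored = x[0:4]
print(step_ignored)  # [0 1 2 3]


def matches_strided_load(out) -> bool:
    """Hypothetical helper: True if `out` equals the expected strided load."""
    return np.array_equal(np.asarray(out), expected)
```

Calling matches_strided_load(out) on the Pallas result from the repro would be expected to return True once strided loads are honored; per the report it currently returns False on the CUDA backend.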

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.12.3 | packaged by Anaconda, Inc. | (main, Apr 19 2024, 16:50:38) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='HomeLinux', release='6.1.0-9-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.1.27-1 (2023-05-08)', machine='x86_64')

$ nvidia-smi
Tue Apr 23 14:54:02 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:17:00.0 Off |                  Off |
|  0%   56C    P2    39W / 450W |    495MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

hirayaku, Apr 23 '24 18:04