[pallas] Interpreter mismatch for masked OOB indexing
Description
In Triton (if I have read this correctly), masked-out lanes of a load/store are never accessed, so you can request an out-of-bounds index into a ref as long as that lane is masked. The current interpreter instead uses dynamic_slice/dynamic_slice_update and applies the mask afterwards. In line with JAX's 'always be in bounds' design, a slice that overruns the edge of the array has its start index shifted back so the slice fits (when possible). This leads to a disconnect between interpreter and compiled Pallas outputs.
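For illustration, the clamping behaviour of jax.lax.dynamic_slice can be seen directly, independent of Pallas (a minimal sketch):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8)
# Requesting a 4-element slice starting at index 6 would overrun the
# 8-element array; JAX clamps the start index back to 4 so the slice fits.
print(lax.dynamic_slice(x, (6,), (4,)))  # [4 5 6 7]
```

This is the same shift that the interpreter inherits when it lowers a masked load to dynamic_slice.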
I know Triton is not Pallas: has the desired behaviour for these cases been deliberately changed in Pallas? If so, this isn't a bug, but it needs documenting.
I've opened a pull request fixing this, with tests: https://github.com/google/jax/pull/21298
Here is a minimal Colab reproduction showing the shift in load indices.
import jax
from functools import partial
from jax import numpy as jnp, jit
from jax.experimental import pallas as pl

def masked_load_pallas_kernel(x_ref, o_ref):
    i = jnp.array(3)
    # Mask out the lanes whose shifted index would run past the end of the ref.
    mask = jnp.arange(x_ref.shape[0]) + i < x_ref.shape[0]
    x = pl.load(x_ref, pl.dslice(i, mask.shape[0]), mask=mask, other=-1)
    o_ref[:] = x

@partial(jit, static_argnames=('interpret',))
def masked_load(x: jax.Array, interpret: bool = True):
    return pl.pallas_call(masked_load_pallas_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                          interpret=interpret,
                          )(x)

x = jnp.arange(16)
print(f'Input:\nx:\n{x}\n\nOutput:')
for interpret in (True, False):
    print(f'Interpret: {interpret}\n{masked_load(x, interpret=interpret)}')
x = jnp.arange(16)
print(f'Input:\nx:\n{x}\n\nOutput:')
for interpret in (True, False):
print(f'Interpret: {interpret}\n{masked_load(x, interpret=interpret)}')
Input:
x:
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
Output:
Interpret: True
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 -1 -1 -1]
Interpret: False
[ 3 4 5 6 7 8 9 10 11 12 13 14 15 -1 -1 -1]
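The interpret=True result is exactly what dynamic_slice clamping predicts: a 16-element slice starting at index 3 on a 16-element array is shifted back to start 0, and the mask is then applied on top. This can be reproduced outside Pallas (a sketch of the mechanism, not the interpreter's actual code path):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16)
mask = jnp.arange(16) + 3 < 16          # lanes 0..12 are in bounds
# Start index 3 is clamped to 0 so the full 16-element window fits.
clamped = lax.dynamic_slice(x, (3,), (16,))
print(jnp.where(mask, clamped, -1))     # matches the Interpret: True output
```

The compiled (Triton) path, by contrast, reads from the unclamped addresses and only skips the masked lanes, giving the shifted values 3..15.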
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='b70fe499e42d', release='6.1.58+', version='#1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023', machine='x86_64')
$ nvidia-smi
Thu May 9 08:55:06 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA L4 Off | 00000000:00:03.0 Off | 0 |
| N/A 63C P0 30W / 72W | 17235MiB / 23034MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
(Problem persists in 0.4.28)