[Pallas] NotImplementedError: Unimplemented primitive in Pallas GPU lowering: slice.
Related #18897
Reproducing code:
import jax
from jax.experimental import pallas as pl
# take first two elements
def kernel(x_ref, o_ref):
    x = x_ref[...]
    o_ref[...] = jax.lax.slice(x, (0,), (2,))
def call_kernel(x: jax.Array) -> jax.Array:
    return pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((2,), x.dtype))(x)
x = jax.numpy.arange(4)
y = call_kernel(x)
Output:
Traceback (most recent call last):
  File "/home/alex/work/pallas-test/reproducer.py", line 13, in <module>
    y = call_kernel(x)
  File "/home/alex/work/pallas-test/reproducer.py", line 10, in call_kernel
    return pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((2,), x.dtype))(x)
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1621, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1564, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 280, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 323, in lower_jaxpr_to_triton_ir
    raise NotImplementedError(
NotImplementedError: Unimplemented primitive in Pallas GPU lowering: slice. Please file an issue on https://github.com/google/jax/issues.
Unfortunately, there is no equivalent of slice in Triton, so we don't currently have a way of lowering this. However, in your example, you could accomplish what you want by using pl.load with the correct indices.
Makes sense. This is actually just a small reproducing example; I initially hit the error when trying to use jax.lax.associative_scan within a kernel. Do you think pl.load could be used to replace slices in this scan?
I think it could, but it might not be very performant. The general question to ask is: can I implement this efficiently using Triton? We are limited by the set of ops that Triton supports, so if you can't express it in Triton, it's unlikely you can express it in Pallas on GPU.
@sharadmv I saw you directed me to this issue, but I don't understand how pl.load would solve my problem. Is there a definition of this primitive and an explanation of how to use it?
Thank you!
We don't have support for slicing because Triton doesn't support the op, afaik. However, suppose you had something like this:
def foo(x_ref):
    x = x_ref[...]
    x_slice = x[0]
you could instead do this:
def foo(x_ref):
    x_slice = x_ref[0]
and that should work.
That's basically what pl.load does: indexing the ref directly loads only the elements you ask for.
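Applied to the reproducer at the top of this issue, the same idea might look like the sketch below. It indexes the input ref with a static slice instead of loading the whole array and calling jax.lax.slice on the value. The interpret=True flag is an assumption added here so the sketch can run without a GPU; it is not part of the original reproducer, and whether this pattern lowers on a given Triton backend and JAX version has not been checked here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(x_ref, o_ref):
    # Read the first two elements straight from the input ref, rather than
    # loading everything with x_ref[...] and slicing the loaded value.
    o_ref[...] = x_ref[0:2]


def call_kernel(x: jax.Array) -> jax.Array:
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((2,), x.dtype),
        interpret=True,  # assumption: interpreter mode, so no GPU is required
    )(x)


x = jnp.arange(4)
y = call_kernel(x)
print(y)
```

Writing the result with o_ref[...] = ... is unchanged from the reproducer; only the read side moves from a slice of a loaded value to an index on the ref.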
In my code, the issue seems to arise from the o_ref[...] = output assignment, which I'm not sure how to handle since I need to return my output. The majority of Pallas examples seem to do this, so I'm not sure how else to return the output. Thanks for the help!