[Pallas] NotImplementedError: Unimplemented primitive in Pallas GPU lowering: slice.

Status: Open. vvvm23 opened this issue 2 years ago · 6 comments

Related #18897

Reproducing code:

import jax
from jax.experimental import pallas as pl

# take first two elements
def kernel(x_ref, o_ref):
    x = x_ref[...]
    o_ref[...] = jax.lax.slice(x, (0,), (2,))

def call_kernel(x: jax.Array) -> jax.Array:
    return pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((2,), x.dtype))(x)

x = jax.numpy.arange(4)
y = call_kernel(x)

Output:

Traceback (most recent call last):
  File "/home/alex/work/pallas-test/reproducer.py", line 13, in <module>
    y = call_kernel(x)
  File "/home/alex/work/pallas-test/reproducer.py", line 10, in call_kernel
    return pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((2,), x.dtype))(x)
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1621, in pallas_call_lowering
    compilation_result = compile_jaxpr(
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 1564, in compile_jaxpr
    lowering_result = lower_jaxpr_to_triton_module(
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 280, in lower_jaxpr_to_triton_module
    () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
  File "/home/alex/.pyenv/versions/3.10.12/envs/venv/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py", line 323, in lower_jaxpr_to_triton_ir
    raise NotImplementedError(
NotImplementedError: Unimplemented primitive in Pallas GPU lowering: slice. Please file an issue on https://github.com/google/jax/issues.

vvvm23, Jan 05 '24 13:01

Unfortunately, there is no equivalent of slice in Triton, so we currently have no way of lowering it. However, in your example you could accomplish what you want by using pl.load with the correct indices.
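A minimal sketch of that suggestion applied to the reproducer, assuming a JAX version where Pallas refs accept static slice indexing (later in this thread, indexing a ref directly is described as what pl.load means). interpret=True is an addition for illustration only, so the sketch can run without a GPU; drop it to target the Triton backend.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(x_ref, o_ref):
    # Index the ref itself instead of slicing a loaded value: this reads only
    # the first two elements, so no lax.slice is needed inside the kernel.
    o_ref[...] = x_ref[0:2]


def call_kernel(x: jax.Array) -> jax.Array:
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((2,), x.dtype),
        interpret=True,  # assumption: interpreter mode so this runs on CPU; remove on GPU
    )(x)


x = jnp.arange(4)
y = call_kernel(x)
print(y)  # [0 1]
```

The only change from the reproducer is where the slice happens: on the ref (a load of two elements) rather than on the already-loaded array value.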

sharadmv, Jan 08 '24 19:01

Makes sense. This is actually just a small reproducer; I originally hit the error while trying to use jax.lax.associative_scan inside a kernel. Do you think pl.load could be used to replace the slices in that scan?

vvvm23, Jan 08 '24 19:01

I think it could, but it might not be very performant. The general question to ask is: can I implement this efficiently in Triton? We are limited by the set of ops that Triton supports, so if you can't express something in Triton, it's unlikely you can express it in Pallas on GPU.

sharadmv, Jan 08 '24 21:01

@sharadmv I saw you directed me to this issue, but I don't understand how pl.load would solve my problem. Is there documentation for this primitive and how to use it?

Thank you!

bsaoptima, Feb 29 '24 09:02

We don't support slicing because, as far as I know, Triton doesn't support the op. However, if you were doing something like this:


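def foo(x_ref):
  x = x_ref[...]  # load the whole block into an array value
  x_slice = x[0]  # indexing the loaded value can lower to lax.slice, which the GPU backend can't lower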

you could instead do this:

def foo(x_ref):
  x_slice = x_ref[0]  # index the ref directly, so only the indexed slice is loaded

and that should work.

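That's basically what pl.load means: indexing the ref loads just the slice you need, instead of loading the whole block and then slicing the resulting value.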
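Filled out into a complete kernel, the pattern in this comment might look like the sketch below. The kernel name first_row_kernel, the (2, 4) input shape, the doubling step, and interpret=True (so it runs without a GPU) are illustrative assumptions, not part of the original comment. Note that the result is still written through o_ref[...]; only the read changes.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def first_row_kernel(x_ref, o_ref):
    # Instead of `x = x_ref[...]; x_slice = x[0]`, index the ref directly so
    # only row 0 is read.
    x_slice = x_ref[0]
    # Writing into the output ref is how a Pallas kernel produces its result;
    # pallas_call returns the array that backs o_ref.
    o_ref[...] = x_slice * 2


def first_row_doubled(x: jax.Array) -> jax.Array:
    return pl.pallas_call(
        first_row_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[1],), x.dtype),
        interpret=True,  # assumption: interpreter mode for illustration; remove on GPU
    )(x)


x = jnp.arange(8).reshape(2, 4)
print(first_row_doubled(x))  # [0 2 4 6]
```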

sharadmv, Feb 29 '24 23:02

In my code, the issue seems to arise from o_ref[...] = output, which I'm not sure how to handle since I need to return my output. Most Pallas examples use this pattern, so I'm not sure how else to return the output. Thanks for the help!

bsaoptima, Mar 01 '24 08:03