
Unimplemented primitive in Pallas GPU lowering: scatter-add

Open · bsaoptima opened this issue 1 year ago · 1 comment

Description

I previously opened an issue about slicing, which I can now handle thanks to all the help provided, but I have now run into something else that is not working properly.

out = out.at[top:bottom, :].add(
    sliced_diag * sliced_other
)

giving the error:

NotImplementedError: Unimplemented primitive in Pallas GPU lowering: scatter-add.

Here I am just doing a simple add into an empty matrix, so I don't see why a GPU could not handle it. Am I missing something?

Many thanks!

System info (python version, jaxlib version, accelerator, etc.)

python=3.11.8 jax=0.4.25 jaxlib=0.4.25+cuda11.cudnn86 jaxtyping=0.2.25

bsaoptima, Mar 11 '24 13:03
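A minimal sketch, in plain JAX on CPU rather than inside a Pallas kernel, of why the error names scatter-add even though the code only adds into a slice: `x.at[...].add(...)` is implemented with the `scatter-add` primitive, which `jax.make_jaxpr` makes visible. The function names, the shapes, and the static bounds `TOP`/`BOTTOM` are hypothetical stand-ins for the issue's `top`/`bottom`. The second function is one scatter-free rewrite for a static, contiguous row block (slice, add, then `lax.dynamic_update_slice`); whether those primitives are accepted by the Pallas GPU lowering in a given JAX version is an assumption not checked here.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical static row bounds standing in for `top` and `bottom`.
TOP, BOTTOM = 2, 5


def add_with_at(out, update):
    # The pattern from the issue; JAX lowers this through scatter-add.
    return out.at[TOP:BOTTOM, :].add(update)


def add_without_scatter(out, update):
    # Read the row block, add element-wise, then write the block back.
    block = out[TOP:BOTTOM, :] + update
    return lax.dynamic_update_slice(out, block, (TOP, 0))


out = jnp.zeros((8, 4), dtype=jnp.float32)
update = jnp.ones((BOTTOM - TOP, 4), dtype=jnp.float32)

# Inspect which primitives each version traces to.
print("scatter-add" in str(jax.make_jaxpr(add_with_at)(out, update)))
print("scatter-add" in str(jax.make_jaxpr(add_without_scatter)(out, update)))
# Check that both versions compute the same result.
print(jnp.allclose(add_with_at(out, update), add_without_scatter(out, update)))
```

Inside an actual Pallas kernel, the more natural route is usually to read and write the output `Ref` directly instead of building a new array with `.at[]`, but the details of ref indexing vary between JAX releases.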

It might be tricky to support scatter-add in its entirety in Triton, but maybe we can support a handful of useful special cases.

superbobry, Mar 15 '24 11:03
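One plausible reason the general case is harder, sketched with `jnp` on CPU: scatter-add must accumulate every update whose index collides with another, so a general lowering has to handle duplicate indices. A static, contiguous slice like the one in the report touches each element at most once and behaves like an element-wise add on that block. Which special cases the Pallas maintainers had in mind is not stated in the thread; this only illustrates the semantic difference.

```python
import jax.numpy as jnp

# General scatter-add: index 0 appears twice, so both updates accumulate there.
indices = jnp.array([0, 0, 2])
updates = jnp.array([1.0, 2.0, 5.0])
general = jnp.zeros(4).at[indices].add(updates)
print(general)  # [3. 0. 5. 0.]

# Special case: a static, contiguous slice hits each element at most once,
# so it is equivalent to an element-wise add on that block.
block = jnp.zeros(4).at[1:3].add(jnp.array([1.0, 2.0]))
print(block)  # [0. 1. 2. 0.]
```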