jax
Unimplemented primitive in Pallas GPU lowering: scatter-add
Description
I previously opened an issue about slicing, which I can now handle thanks to all the help provided, but now I have run into something else that is not working properly:
out = out.at[top:bottom, :].add(
sliced_diag * sliced_other
)
giving the error:
NotImplementedError: Unimplemented primitive in Pallas GPU lowering: scatter-add.
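To see where `scatter-add` comes from, it can help to trace the same kind of update in plain JAX, outside Pallas. The sketch below is only illustrative: `update_rows`, the `(8, 4)` output shape and the `TOP`/`BOTTOM` bounds are hypothetical stand-ins for the reporter's variables. Tracing it with `jax.make_jaxpr` should show that `.at[...].add(...)` is expressed with the `scatter-add` primitive even when the slice bounds are plain Python integers. That primitive is what the Pallas GPU (Triton) lowering rejects.

```python
import jax
import jax.numpy as jnp

# Hypothetical static row bounds and shapes, chosen only for illustration.
TOP, BOTTOM = 2, 6

def update_rows(out, sliced_diag, sliced_other):
    # Same pattern as in the report: an indexed add into a slice of `out`.
    return out.at[TOP:BOTTOM, :].add(sliced_diag * sliced_other)

out = jnp.zeros((8, 4), dtype=jnp.float32)
sliced_diag = jnp.ones((BOTTOM - TOP, 4), dtype=jnp.float32)
sliced_other = jnp.full((BOTTOM - TOP, 4), 2.0, dtype=jnp.float32)

# Print the traced program; the indexed add shows up as a scatter-add equation.
jaxpr = jax.make_jaxpr(update_rows)(out, sliced_diag, sliced_other)
print(jaxpr)

# Outside Pallas the same code runs normally.
result = update_rows(out, sliced_diag, sliced_other)
```

So the GPU itself could handle the addition; the limitation is that the Pallas Triton lowering has no rule for the general `scatter-add` primitive that `.at[...].add` is lowered to.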
Here I am only doing a simple add into an empty matrix, so I don't see why a GPU could not handle it. Am I missing something?
Many thanks!
System info (python version, jaxlib version, accelerator, etc.)
python=3.11.8 jax=0.4.25 jaxlib=0.4.25+cuda11.cudnn86 jaxtyping=0.2.25
It might be tricky to support scatter-add in its entirety in Triton, but maybe we can support a handful of useful special cases.
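One such special case is the pattern in this report: adding into a contiguous block of rows whose bounds are known at trace time. Assuming `top` and `bottom` are static Python integers (the report does not say), the update can be written as a read-modify-write on the output `Ref` instead of `.at[...].add`, which avoids `scatter-add`. This is a minimal sketch, not a confirmed fix: the kernel, shapes and bounds are hypothetical, and `interpret=True` only checks semantics on CPU. It has not been verified that the Triton lowering in jax 0.4.25 accepts this form.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Hypothetical shapes and static row bounds, for illustration only.
N, M = 8, 4
TOP, BOTTOM = 2, 6

def kernel(diag_ref, other_ref, o_ref):
    # Start from an all-zero output block.
    o_ref[...] = jnp.zeros((N, M), dtype=jnp.float32)
    update = diag_ref[...] * other_ref[...]
    # Read-modify-write on a statically sliced region: load, add, store.
    # No scatter primitive is involved.
    o_ref[TOP:BOTTOM, :] = o_ref[TOP:BOTTOM, :] + update

@jax.jit
def run(diag, other):
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((N, M), jnp.float32),
        interpret=True,  # run on CPU; remove to target the GPU backend
    )(diag, other)

diag = jnp.arange(BOTTOM - TOP, dtype=jnp.float32)[:, None] * jnp.ones((1, M), jnp.float32)
other = jnp.full((BOTTOM - TOP, M), 3.0, dtype=jnp.float32)
result = run(diag, other)
```

Writing through the `Ref` keeps the operation a plain load, add and store, which is the kind of restricted case that is much easier to lower than an arbitrary scatter with data-dependent indices.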