NNlib.jl

Extending scatter! to work with CUDA sparse arrays

alonsoC1s opened this pull request.

Aims to fix #647 by extending the signature of scatter! to accept AbstractCuSparseArray, a CUDA array type that the original method notably excludes. With this patch, calling scatter! with sparse arrays from CUDA.CUSPARSE dispatches to the CUDA-specialized method instead of the generic CPU fallback, which triggered a scalar indexing error. In my testing, the existing CUDA kernels work fine with CuSparseArrayCSC.

The proposed implementation, perhaps inelegantly, simply widens the argument types in the signature with Union{...}. I'm open to discussing cleaner ways to implement this. Ideally, AbstractCuSparseArray would be a subtype of AnyCuArray.
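A minimal, GPU-free sketch of the dispatch idea described above, under stated assumptions: MockDenseGPU and MockSparseGPU are hypothetical stand-ins for AnyCuArray and AbstractCuSparseArray, and scatter_before! / scatter_after! are hypothetical functions, not NNlib's scatter!. The sketch only shows how widening one argument type with Union{...} routes a sparse destination to the specialized method instead of the generic fallback; which arguments the actual patch widens is not reproduced here.

```julia
# Hypothetical stand-ins for the real CUDA.jl types, so this runs without a GPU.
abstract type MockDenseGPU end   # plays the role of AnyCuArray (assumption)
abstract type MockSparseGPU end  # plays the role of AbstractCuSparseArray (assumption)
struct DenseGPU <: MockDenseGPU end
struct SparseCSC <: MockSparseGPU end

# Before: the specialized method only accepts dense GPU arrays, so a sparse
# destination falls through to the generic (scalar-indexing, CPU-style) fallback.
scatter_before!(dst, src) = :generic_cpu
scatter_before!(dst::MockDenseGPU, src::MockDenseGPU) = :gpu_kernel

# After: widen the destination type with a Union so sparse arrays also match
# the specialized method, which is more specific than the untyped fallback.
const GPUDst = Union{MockDenseGPU, MockSparseGPU}
scatter_after!(dst, src) = :generic_cpu
scatter_after!(dst::GPUDst, src::MockDenseGPU) = :gpu_kernel

println(scatter_before!(SparseCSC(), DenseGPU()))  # prints generic_cpu
println(scatter_after!(SparseCSC(), DenseGPU()))   # prints gpu_kernel
```

Making the sparse types part of AnyCuArray itself would remove the need for a Union in each signature, although it would also route sparse arrays into every other method typed on AnyCuArray, which may not all support them.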

PR Checklist

  • [x] Tests are added
  • [x] Documentation, if applicable

alonsoC1s, Nov 14 '25 15:11

The Lux integration tests fail because always_inliner! is not defined, along with what look like internal Enzyme errors. I'm not sure whether this is related to this PR.

alonsoC1s, Nov 15 '25 09:11

Seems fine. Is it possible to add a test on CI somehow, perhaps in https://github.com/FluxML/NNlib.jl/blob/master/test/ext_cuda/scatter.jl ?

mcabbott, Nov 19 '25 05:11

I added the CUSPARSE sparse matrix types to the list of array types that are automatically tested.

@mcabbott Any thoughts on making the implementation less ugly? Should I open an issue on CUDA.jl suggesting making CUSPARSE arrays subtypes of AnyCuArray?

alonsoC1s, Nov 25 '25 21:11

Oops, I shouldn't have merged this, @alonsoC1s. The scatter tests are failing: https://buildkite.com/julialang/nnlib-dot-jl/builds/1670/steps/canvas?jid=019abcf3-e322-4111-b8f6-9ef4d132aec0

CarloLucibello, Dec 21 '25 09:12