ChainRules.jl

Avoid device-to-host copy in `∇getindex!`

Open · pxl-th opened this issue 1 year ago · 2 comments

Could we use a custom kernel with atomics for `∇getindex!(dx::AbstractGPUArray, dy, inds...)` instead of copying everything to the CPU?

That would let us avoid host-device synchronizations, and such a kernel could be added via a package extension.
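To make the race concrete, here is a minimal sequential sketch of the accumulation a GPU kernel for `∇getindex!` would have to reproduce. `scatter_add_ref!` is a hypothetical name (not ChainRules' API), and the sketch assumes every index is an integer vector, one per dimension of `dx`. When an index repeats, several elements of `dy` must be summed into the same element of `dx`.

```julia
# Hypothetical sequential reference for the scatter-add behind ∇getindex!,
# assuming each index is an integer vector (one per dimension of dx).
function scatter_add_ref!(dx::AbstractArray, dy::AbstractArray,
                          inds::AbstractVector{<:Integer}...)
    @assert ndims(dx) == length(inds) == ndims(dy)
    for c in CartesianIndices(dy)
        # Position in dx that this element of dy was originally read from.
        j = map(getindex, inds, Tuple(c))
        # On a GPU each iteration would run in its own thread; if two
        # positions of dy map to the same j, this read-modify-write races,
        # which is why the kernel needs an atomic add here.
        dx[j...] += dy[c]
    end
    return dx
end

dx = zeros(3)
scatter_add_ref!(dx, [1.0, 2.0, 3.0], [1, 3, 1])  # dx[1] receives 1.0 + 3.0
```

In a GPU version, each `c` would be handled by its own thread and the `+=` would become an atomic add (for example via Atomix.jl's `@atomic`). Since all reads and writes then stay on the device, no device-to-host copy is needed.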

pxl-th · Jun 19 '24 20:06

To be clear, the method that copies to the CPU should only be used when `inds` are arrays, which is where you have to worry about races. For simpler cases like `A[1,:]` or `B[3:end-3]` it should not do this.
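The distinction above comes down to whether an index can select the same position twice. The following is a hypothetical dispatch sketch (the names `may_repeat` and `needs_scatter` are illustrative, not ChainRules' actual methods): scalars, colons, and ranges cannot repeat, so a plain elementwise write is race-free, while general integer arrays can.

```julia
# Hypothetical trait: can this index type select the same position twice?
may_repeat(::Integer) = false                   # e.g. the 1 in A[1,:]
may_repeat(::Colon) = false                     # e.g. the : in A[1,:]
may_repeat(::AbstractRange{<:Integer}) = false  # e.g. 3:end-3; a range never repeats
# Conservative: any other integer array might repeat. Bool masks (Bool <: Integer)
# also land here, which is safe but takes the slower path.
may_repeat(::AbstractArray{<:Integer}) = true

# Take the scatter-add path (atomic kernel or CPU fallback) only when needed.
needs_scatter(inds...) = any(may_repeat, inds)

needs_scatter(1, :)        # false: A[1,:]
needs_scatter(3:7)         # false: a plain range
needs_scatter([1, 1, 2])   # true: repeated positions must be accumulated
```

Other index types, such as arrays of `CartesianIndex`, are deliberately left out of this sketch.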

I think this method was added as the simplest way to solve the problem, but a faster kernel in a package extension would be fine. I believe it's a lot like `NNlib.scatter`.

mcabbott · Jun 19 '24 20:06

> To be clear the method which copies to CPU should only be for inds which are arrays

Yes, that's exactly my situation :)

I can try to put together a PR for this soon.

pxl-th · Jun 19 '24 21:06