ChainRules.jl
Avoid device-to-host copy in `∇getindex!`
Can we use a custom kernel with atomics for `∇getindex!(dx::AbstractGPUArray, dy, inds...)` instead of copying everything to the CPU?
That way we would avoid synchronizations, and the kernel could be added via a package extension.
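A minimal sketch of what such a kernel could look like, assuming KernelAbstractions.jl and Atomix.jl (neither is named in this thread) and covering only the one-dimensional case of a single integer index vector rather than the general `inds...`. `scatter_add_kernel!` and `scatter_add!` are hypothetical names, not ChainRules API. Each work-item adds one element of `dy` into `dx`, and the atomic add keeps repeated indices from racing. The launch stays on the array's backend with no device-to-host copy; the wrapper itself does not synchronize, so the caller synchronizes only when it needs to read the result.

```julia
using KernelAbstractions
using Atomix

# Hypothetical kernel: one work-item per element of dy.
@kernel function scatter_add_kernel!(dx, @Const(dy), @Const(inds))
    k = @index(Global, Linear)
    # Several work-items may target the same dx[j] when inds has repeats,
    # so the read-modify-write must be atomic.
    Atomix.@atomic dx[inds[k]] += dy[k]
end

# Hypothetical wrapper: enqueues the kernel on dx's backend (CPU, CUDA, ...)
# without copying anything to the host and without synchronizing.
function scatter_add!(dx::AbstractArray, dy::AbstractArray, inds::AbstractVector{<:Integer})
    length(dy) == length(inds) ||
        throw(DimensionMismatch("dy and inds must have the same length"))
    backend = KernelAbstractions.get_backend(dx)
    kernel = scatter_add_kernel!(backend)
    kernel(dx, dy, inds; ndrange = length(inds))
    return dx
end

# Usage on the CPU backend; on a GPU, dx, dy and inds would be device arrays.
dx = zeros(4)
scatter_add!(dx, [1.0, 2.0, 3.0], [1, 1, 3])
KernelAbstractions.synchronize(KernelAbstractions.get_backend(dx))
```

The same kernel source runs on any KernelAbstractions backend, which is what would make it a natural fit for a package extension.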
To be clear, the method which copies to the CPU should only be used for `inds` that are arrays, which is where you have to worry about races. For simpler cases like `A[1,:]` or `B[3:end-3]` it should not do this.
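To illustrate why index arrays are the case that needs care, here is a sequential reference for the accumulation, assuming a single integer index vector (`scatter_add_reference!` is a hypothetical name, not ChainRules code). With repeated indices every contribution must be summed into the same slot, so parallel threads writing `dx[j]` at once would race. Ranges such as `3:end-3` never repeat an index, which `allunique` can confirm.

```julia
# Hypothetical sequential reference for the pullback of y = x[inds]:
# each dy[k] is added into dx at position inds[k].
function scatter_add_reference!(dx::AbstractArray, dy::AbstractVector, inds::AbstractVector{<:Integer})
    length(dy) == length(inds) ||
        throw(DimensionMismatch("dy and inds must have the same length"))
    for k in eachindex(dy, inds)
        # `+=` (not `=`) so that repeated indices accumulate.
        dx[inds[k]] += dy[k]
    end
    return dx
end

# dx[1] receives two contributions (1.0 + 2.0).
dx = scatter_add_reference!(zeros(4), [1.0, 2.0, 3.0], [1, 1, 3])

# An index array can repeat entries; a range such as 3:6 cannot.
allunique([1, 1, 3])  # false
allunique(3:6)        # true
```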
I think this method was added as the simplest way to solve the problem, but having a faster kernel in a package extension would be fine. I believe it's a lot like `NNlib.scatter`.
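For comparison, the same accumulation written with NNlib's in-place `scatter!`, assuming NNlib.jl is installed. This only illustrates the analogy to `NNlib.scatter`; it is not an existing ChainRules code path.

```julia
using NNlib

dx = zeros(4)
dy = [1.0, 2.0, 3.0]
inds = [1, 1, 3]

# scatter!(op, dst, src, idx) applies dst[idx[k]] = op(dst[idx[k]], src[k])
# for each k, so repeated indices in `inds` are combined with `+`.
NNlib.scatter!(+, dx, dy, inds)
```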
> To be clear the method which copies to CPU should only be for inds which are arrays
Yes, that's exactly my situation :)
I can try to come up with a PR for this soon.