ChainRules.jl
Make `OneElement` more GPU friendly
Ref. https://github.com/FluxML/Flux.jl/pull/2368. I see a couple of possibly complementary ways to go about this. The easiest would be to define an Adapt rule for `OneElement` so that it is materialized, or replaced with some GPU-friendly equivalent, when passed through `CUDA.cu`. The other would be to define overloads for certain functions, such as `mul!`, that can take advantage of the sparsity.
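To make the two options concrete, here is a rough sketch of what each might look like. It is illustrative only: `SketchOneElement` is a hypothetical stand-in defined here so the example is self-contained, not the actual `ChainRules.OneElement` (whose fields and constructor may differ), and the `mul!` method covers only the matrix-times-vector case without any dimension checks. Whether materializing through `collect` is acceptable on a real GPU path is an open question; it just shows where an Adapt rule would hook in.

```julia
using Adapt, LinearAlgebra

# Hypothetical stand-in for a one-hot-style array: a single nonzero `val`
# stored at position `ind`, zeros everywhere else.
struct SketchOneElement{T,N} <: AbstractArray{T,N}
    val::T
    ind::NTuple{N,Int}
    axes::NTuple{N,Base.OneTo{Int}}
end

Base.size(x::SketchOneElement) = map(length, x.axes)
Base.axes(x::SketchOneElement) = x.axes
Base.getindex(x::SketchOneElement{T,N}, i::Vararg{Int,N}) where {T,N} =
    i == x.ind ? x.val : zero(T)

# Option 1: an Adapt rule that materializes the array densely and then
# adapts the dense copy to the target storage. With `CUDA.cu` as the
# adaptor this would presumably yield a GPU array.
Adapt.adapt_structure(to, x::SketchOneElement) = adapt(to, collect(x))

# Option 2: a `mul!` overload that exploits the sparsity. For a vector x
# with one nonzero at index i, A * x is x.val times column i of A, which
# is a single broadcast over a column view and needs no scalar indexing.
function LinearAlgebra.mul!(y::AbstractVector, A::AbstractMatrix,
                            x::SketchOneElement{<:Any,1})
    y .= x.val .* view(A, :, x.ind[1])
    return y
end
```

A CPU-only usage example, with `Array` standing in for the GPU adaptor:

```julia
x = SketchOneElement(2.0, (3,), (Base.OneTo(4),))
A = reshape(collect(1.0:8.0), 2, 4)

dense = adapt(Array, x)      # dense copy produced by the Adapt rule
y = mul!(zeros(2), A, x)     # same values as A * collect(x)
```

The two approaches are not mutually exclusive: the Adapt rule gives a safe fallback for any operation, while targeted overloads avoid the dense materialization where it matters.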