NNlib.jl
Better errors for un-implemented functions
Functions like gather/scatter give scalar indexing errors when used on CuArrays if you forget to load NNlibCUDA.
Since there is now a very lightweight GPUArraysCore, I think NNlib should depend on it and define simple fallback methods on ::AbstractGPUArray that throw a helpful error.
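A minimal sketch of the idea, assuming GPUArraysCore is a dependency (it does export AbstractGPUArray). The function my_gather is hypothetical, a simplified stand-in for NNlib's gather, and the error text is only illustrative:

```julia
using GPUArraysCore: AbstractGPUArray

# Hypothetical simplified gather: returns src[idx[k]] for each k.
# Generic method: uses scalar indexing, which is fine on CPU arrays
# but slow or disallowed on GPU arrays.
function my_gather(src::AbstractArray, idx::AbstractVector{<:Integer})
    return [src[i] for i in idx]
end

# Fallback for every GPU array type. Since AbstractGPUArray <: AbstractArray,
# this method is more specific than the generic one above. A backend method
# for e.g. CuArray would in turn be more specific than this one.
function my_gather(src::AbstractGPUArray, idx::AbstractVector{<:Integer})
    error("my_gather is not implemented for ", typeof(src), ". ",
          "Try loading the GPU backend package for this array type; ",
          "if it is already loaded, this kernel may not exist for it yet.")
end
```

On the CPU, my_gather([10, 20, 30], [3, 1]) returns [30, 10]. Because CuArray <: AbstractGPUArray, a CuArray method defined in the GPU package wins dispatch over the fallback, so the error only fires when no backend method has been loaded. The same fallback would catch array types like those from Metal or AMDGPU that lack a dedicated method.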
Edit, 2023:
Since https://github.com/FluxML/NNlib.jl/pull/492, NNlib uses package extensions. If I understand correctly, it may still be possible to forget to load cuDNN.
Some kernels don't yet work on Metal / AMDGPU (e.g. https://github.com/FluxML/Flux.jl/issues/2278). Here too, a clear "not implemented" error would be nicer.