ReverseDiff.jl
support for GPU-backed arrays
Originally, I thought we'd need a special macro directive for this, but I now believe this can just work by allowing the user to pass in GPU-backed arrays as input to the API methods. In other words, ReverseDiff won't provide a GPU-backed array type; it will compose with the GPU-backed array types provided by other well-established libraries.
For now, Julia's ArrayFire wrapper should be sufficient to start experimenting with this functionality for a subset of ReverseDiff's methods, at least enough to build some proof-of-concept examples. Eventually, I'd like to support GPUArrays.jl, though it's probably not productive to work on that until GPUArrays has solid linear algebra coverage out of the box.
In some cases, existing derivative definitions might "just work" for GPU-backed arrays (though they might have poor performance). In many cases, we'll likely have to dispatch on the array type to write special derivative methods for GPU-backed arrays.
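The dispatch idea in the paragraph above can be sketched in plain Julia. This is a minimal sketch, not ReverseDiff's actual internals. `FakeGPUArray` is a hypothetical stand-in for a GPU-backed type such as those provided by ArrayFire.jl or GPUArrays.jl; it just wraps a CPU `Array` so the example runs anywhere. `sumabs2_pullback!` is a hypothetical hand-written reverse-mode rule for `sum(abs2, x)`. The generic `AbstractArray` method composes with any array type but relies on scalar indexing, which is typically slow on a GPU. The specialized method is selected by dispatch on the array type.

```julia
# Hypothetical stand-in for a GPU-backed array type. It wraps a CPU Array so
# this sketch runs without a GPU; a real type would hold device memory.
struct FakeGPUArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(A::FakeGPUArray) = size(A.data)
Base.getindex(A::FakeGPUArray, i::Int...) = A.data[i...]
Base.setindex!(A::FakeGPUArray, v, i::Int...) = (A.data[i...] = v)

# Hypothetical reverse-mode rule for y = sum(abs2, x): given the output
# adjoint y_adj, accumulate x_adj += 2 .* x .* y_adj.
# Generic fallback: works for any AbstractArray through elementwise indexing,
# so it "just works" for new array types but may perform poorly on a GPU.
function sumabs2_pullback!(x_adj::AbstractArray, x::AbstractArray, y_adj::Real)
    for i in eachindex(x, x_adj)
        x_adj[i] += 2 * x[i] * y_adj
    end
    return x_adj
end

# Specialized method chosen by dispatch on the array type. A real GPU version
# would launch a fused kernel; this sketch broadcasts over the wrapped data.
function sumabs2_pullback!(x_adj::FakeGPUArray, x::FakeGPUArray, y_adj::Real)
    x_adj.data .+= 2 .* x.data .* y_adj
    return x_adj
end
```

With these definitions, `sumabs2_pullback!(zeros(3), [1.0, 2.0, 3.0], 1.0)` takes the generic path and returns `[2.0, 4.0, 6.0]`, while passing `FakeGPUArray` arguments selects the specialized method and produces the same values.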
cc the GPU-related folks: @SimonDanisch @vchuravy @maleadt @ranjanan @MikeInnes
Can we compile a list of operations that you need, to prioritize it in GPUArrays?
Good idea. It would be nice to have a "master list" of supported TrackedArray functionality anyway.
What is the current status of ReverseDiff.jl working with GPUArrays.jl?
Various examples are working, but we haven't attempted to run a complete test suite or anything ;) Help figuring out what's covered and what needs improvement would be appreciated!