DaggerFlux.jl
Don't assume the model is on a CUDA device
Currently, DaggerChain communicates to Dagger that the wrapped model is located on a CUDA GPU, which is not necessarily true (and shouldn't be a requirement anyway). We should provide functions which can move the model to the GPU and communicate the correct location to Dagger, and/or auto-detect where a model currently resides.
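As a rough sketch of the auto-detect and move-to-device helpers suggested above (the names `on_cuda` and `to_device` are hypothetical, not existing DaggerFlux.jl API):

```julia
using Flux, CUDA

# Hypothetical helper: detect whether a Flux model currently resides on a
# CUDA device by checking if any trainable parameter is a CuArray.
on_cuda(model) = any(p -> p isa CUDA.CuArray, Flux.params(model))

# Hypothetical helper: move the model and report where it ended up, so the
# caller can tell Dagger where the wrapped model actually lives instead of
# unconditionally asserting a CUDA GPU.
function to_device(model; cuda=CUDA.functional())
    cuda ? (Flux.gpu(model), :cuda) : (Flux.cpu(model), :cpu)
end
```

The wrapper could then be constructed from the reported location rather than hard-coding CUDA.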
https://github.com/FluxML/DaggerFlux.jl/blob/101c23b86dd46244070401454db613067200ff5b/src/dflux.jl#L11 is where we assert that something is on the GPU. I believe we can safely remove the CUDA part now?
If we remove it, then we won't get automatic GPU execution (because GPU execution is disabled by default). We probably need a dispatch-based API in Dagger to enable GPU execution for certain functions (like DaggerChain).
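A minimal sketch of what that dispatch-based opt-in could look like; `use_gpu` and the stub `DaggerChain` below are assumptions for illustration, not existing Dagger or DaggerFlux.jl APIs:

```julia
# Stub standing in for DaggerFlux's DaggerChain, for illustration only.
struct DaggerChain{T}
    chain::T
end

# Hypothetical trait: GPU execution stays disabled by default...
use_gpu(::Any) = false
# ...but specific callables like DaggerChain can opt in, so the scheduler
# may consider GPU processors when executing them.
use_gpu(::DaggerChain) = true
```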
Do we then intend to revert the changes to daglayer in #19 before merging?
I believe this has been resolved.