
Don't assume the model is on a CUDA device

Open jpsamaroo opened this issue 3 years ago • 3 comments

Currently, DaggerChain communicates to Dagger that the wrapped model is located on a CUDA GPU, which is not necessarily true (and shouldn't be a requirement anyway). We should provide functions which can move the model to the GPU and communicate the correct location to Dagger, and/or auto-detect where a model currently resides.
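For the auto-detection half of this, one possible approach (a sketch, not anything DaggerFlux.jl currently provides — `on_gpu` is an invented name) is to inspect the array types of the model's parameters:

```julia
# Hypothetical sketch: detect whether a Flux model currently resides on a
# CUDA GPU by checking its parameter arrays. `on_gpu` is an assumption for
# illustration, not an existing DaggerFlux.jl or Dagger.jl function.
using Flux, CUDA

on_gpu(model) = any(p -> p isa CUDA.CuArray, Flux.params(model))
```

A mover function could then pair `Flux.gpu`/`Flux.cpu` with updating the scope/processor information handed to Dagger, so the two never disagree.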

jpsamaroo avatar Apr 02 '22 14:04 jpsamaroo

https://github.com/FluxML/DaggerFlux.jl/blob/101c23b86dd46244070401454db613067200ff5b/src/dflux.jl#L11 is where we assert that something is on the GPU. I believe we can safely remove the CUDA part now?

DhairyaLGandhi avatar Apr 04 '22 15:04 DhairyaLGandhi

If we remove it, then we won't get automatic GPU execution (because GPU execution is disabled by default). We probably need a dispatch-based API in Dagger to enable GPU execution for certain functions (like DaggerChain).
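Such a dispatch-based opt-in might look like a trait function that Dagger queries before scheduling (a sketch only — `allows_gpu` and the `DaggerChain` stub below are placeholders, not an existing Dagger API):

```julia
# Hypothetical dispatch-based opt-in: a trait Dagger could consult to decide
# whether a given callable may be scheduled on a GPU processor.
struct DaggerChain  # stand-in for the real DaggerFlux.jl wrapper type
    chain
end

allows_gpu(::Any) = false          # default: GPU execution stays disabled
allows_gpu(::DaggerChain) = true   # DaggerChain opts in explicitly
```

This would keep GPU execution off by default while letting wrapper types like `DaggerChain` opt in without hard-coding a CUDA assertion.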

jpsamaroo avatar Apr 04 '22 15:04 jpsamaroo

Do we then intend to revert the changes to daglayer in #19 before merging?

DhairyaLGandhi avatar Apr 07 '22 17:04 DhairyaLGandhi

I believe this has been resolved.

jpsamaroo avatar Jul 31 '23 14:07 jpsamaroo