DiffEqFlux.jl

Full support for gpu casting in neural DE layers

Open CarloLucibello opened this issue 4 years ago • 1 comment

I'd like to send a NeuralODE object (embedded neural network included) to the GPU, instead of creating the NeuralODE out of a model that is already on the GPU. Below is an example:

julia> using DiffEqFlux, Flux

# This does the job
julia> node = NeuralODE(Dense(2,2) |> gpu, (0f0,1f0));

# everything lives on gpu
julia> for f in fieldnames(typeof(node)); println("$f => $(typeof(getfield(node, f)))") ; end
model => Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}
p => CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
re => Flux.var"#122#123"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}
tspan => Tuple{Float32, Float32}
args => Tuple{}
kwargs => Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}


# DESIRED FEATURE, currently a no-op
julia> node = NeuralODE(Dense(2,2), (0f0,1f0)) |> gpu;

# fields are still on cpu
julia> for f in fieldnames(typeof(node)); println("$f => $(typeof(getfield(node, f)))") ; end
model => Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}
p => Vector{Float32}
re => Flux.var"#122#123"{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
tspan => Tuple{Float32, Float32}
args => Tuple{}
kwargs => Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}

# This is the usage in the example at
#   https://diffeqflux.sciml.ai/stable/examples/mnist_conv_neural_ode/
# where the role of the outer `gpu` is not clear
julia> node = NeuralODE(Dense(2,2) |> gpu, (0f0,1f0)) |> gpu;
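
For context, NeuralODE builds the flat parameter vector p and the rebuild closure re from the model via Flux.destructure (the Flux.var"#122#123" closure in the output above), so the three fields are tied together. A minimal sketch of that coupling, using plain Flux with no DiffEqFlux needed:

julia> using Flux

julia> p, re = Flux.destructure(Dense(2,2));  # p: flat parameter vector, re: rebuild closure

julia> re(p) isa Dense  # re closes over the layer's structure, so p must stay in sync with it
true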

While normally one could get away with Flux.@functor NeuralODE, here, since the internal fields p, re, and model are related, I think the appropriate solution would look something like this:

function Flux.gpu(node::NeuralODE)
  # Rebuild the layer from the GPU copy of the model so that
  # p and re are regenerated consistently, forwarding solver args/kwargs
  model = gpu(node.model)
  NeuralODE(model, node.tspan, node.args...; node.kwargs...)
end
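
With an overload like that defined, the desired line from above should just work; a sketch of the expected behavior (assuming a working CUDA setup):

julia> node = NeuralODE(Dense(2,2), (0f0,1f0)) |> gpu;

julia> node.p isa CuArray  # parameters now live on the GPU
true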

CarloLucibello · Oct 16 '21 11:10

That looks like a good solution. If someone made overloads like that for each of the neural architectures in here, that would be an easy merge.
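
For instance, the analogous overload for NeuralDSDE might look like this (a sketch, assuming its fields are model1/model2 and its constructor takes the drift and diffusion models plus tspan):

function Flux.gpu(node::NeuralDSDE)
  # Rebuild from GPU copies of the drift and diffusion models so the
  # flattened parameters and rebuild closures are regenerated consistently
  NeuralDSDE(gpu(node.model1), gpu(node.model2), node.tspan,
             node.args...; node.kwargs...)
end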

ChrisRackauckas · Oct 17 '21 18:10