DiffEqFlux.jl
Full support to gpu cast in neural DE layers
I'd like to send a NeuralODE object (embedded neural network included) to the GPU, instead of creating a NeuralODE from a model that is already on the GPU. Below is an example:
julia> using DiffEqFlux, Flux
# This does the job
julia> node = NeuralODE(Dense(2,2) |> gpu, (0f0,1f0));
# everything lives on the GPU
julia> for f in fieldnames(typeof(node)); println("$f => $(typeof(getfield(node, f)))") ; end
model => Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}
p => CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
re => Flux.var"#122#123"{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}
tspan => Tuple{Float32, Float32}
args => Tuple{}
kwargs => Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}
# DESIRED FEATURE, currently a no-op
julia> node = NeuralODE(Dense(2,2), (0f0,1f0)) |> gpu;
# fields are still on the CPU
julia> for f in fieldnames(typeof(node)); println("$f => $(typeof(getfield(node, f)))") ; end
model => Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}
p => Vector{Float32}
re => Flux.var"#122#123"{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}
tspan => Tuple{Float32, Float32}
args => Tuple{}
kwargs => Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}
# This is the usage in the example at
# https://diffeqflux.sciml.ai/stable/examples/mnist_conv_neural_ode/
# where the role of the outer `gpu` is not clear
julia> node = NeuralODE(Dense(2,2) |> gpu, (0f0,1f0)) |> gpu;
While normally one could get away with Flux.@functor NeuralODE, here, since the internal fields p, re, and model are related, I think the appropriate solution would look something like this:
function Flux.gpu(node::NeuralODE)
    model = gpu(node.model)
    NeuralODE(model, node.tspan)
end
That looks like a good solution. If someone made overloads like that for each of the neural architectures in here, that would be an easy merge.
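A slightly fuller sketch of what such an overload might look like, forwarding the stored solver args and kwargs as well so the rebuilt layer behaves like the original. This is an assumption-laden illustration: it presumes the NeuralODE constructor accepts trailing args and kwargs and rebuilds p and re from the model, and that a matching cpu overload is wanted for symmetry; it is not the merged implementation.

```julia
using DiffEqFlux, Flux

# Hypothetical overloads: move the wrapped model across devices and
# reconstruct the layer, so p and re are rebuilt consistently with it.
function Flux.gpu(node::NeuralODE)
    NeuralODE(gpu(node.model), node.tspan, node.args...; node.kwargs...)
end

function Flux.cpu(node::NeuralODE)
    NeuralODE(cpu(node.model), node.tspan, node.args...; node.kwargs...)
end
```

With overloads like these, `NeuralODE(Dense(2, 2), (0f0, 1f0)) |> gpu` from the example above would no longer be a no-op, and the same pattern could be repeated for the other neural DE layers.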