
[WIP] feat: compatibility of NNODE with CUDA

sathvikbhagavan opened this issue 2 weeks ago

I added a custom broadcast function for the GPU using KernelAbstractions.jl (KA.jl), but it errors when running solve.
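For context, the kernel is roughly of the following shape (a simplified sketch with hypothetical names rhs_columns!/gpu_rhs!, not the exact code in the branch): it applies the out-of-place RHS f(u, p, t) column-wise over a batch of time points.

using KernelAbstractions

# Sketch: evaluate the out-of-place ODE RHS once per column of u,
# where column i holds the state at time t[i].
@kernel function rhs_columns!(f, du, u, p, t)
    i = @index(Global, Linear)
    dui = f(@view(u[:, i]), p, t[i])  # f returns an SVector
    for j in eachindex(dui)
        du[j, i] = dui[j]
    end
end

function gpu_rhs!(f, du, u, p, t; workgroupsize = 64)
    backend = get_backend(du)
    rhs_columns!(backend, workgroupsize)(f, du, u, p, t; ndrange = length(t))
    KernelAbstractions.synchronize(backend)
    return du
end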

MWE:

using Random, NeuralPDE
using OrdinaryDiffEq
using Lux, OptimizationOptimisers
using LuxCUDA, ComponentArrays, StaticArrays

rng = Random.default_rng()
Random.seed!(100)
const gpud = Lux.gpu_device()

# Lotka-Volterra dynamics, out-of-place with StaticArrays
function f(u, p, t)
    SVector{2}(p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2])
end
p = SVector{4}(1.5, 1.0, 3.0, 1.0)
u0 = SVector{2}(1.0, 1.0)
prob_oop = ODEProblem{false}(f, u0, (0.0, 3.0), p)
func = Lux.σ
N = 12
chain = Lux.Chain(
    Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
    Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
opt = OptimizationOptimisers.Adam(0.01)
weights = [0.7, 0.2, 0.1]
points = 200
# NNODE with WeightedIntervalTraining, with training moved to the GPU device
alg = NNODE(chain, opt; autodiff = false,
    strategy = NeuralPDE.WeightedIntervalTraining(weights, points), device = gpud)
sol = solve(prob_oop, alg, verbose = false, maxiters = 5000, saveat = 0.01)

This gives the following error:

julia> sol = solve(prob_oop, alg, verbose = false, maxiters = 5000, saveat = 0.01)
┌ Warning: Mixed Precision Inputs received for `weight`: CuArray{Float32, 2, CUDA.DeviceMemory} and `x`: CuArray{Float64, 2, CUDA.DeviceMemory}. Promoting to Float64.
└ @ LuxLibCUDAExt ~/.julia/packages/LuxLib/Q3elb/ext/LuxLibCUDAExt/cublaslt.jl:27
ERROR: Compiling Tuple{CUDA.var"##cufunction#1171", Base.Pairs{Symbol, Union{Nothing, Bool}, Tuple{Symbol, Symbol}, @NamedTuple{always_inline::Bool, maxthreads::Nothing}}, typeof(cufunction), typeof(NeuralPDE.gpu_custom_broadcast!), Type{Tuple{KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicCheck, Nothing, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}}}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, CuDeviceMatrix{Float64, 1}, CuDeviceMatrix{Float64, 1}, CuDeviceVector{Float64, 1}, CuDeviceVector{Float64, 1}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::CUDA.var"##cufunction#1171", ::@Kwargs{always_inline::Bool, maxthreads::Nothing}, ::typeof(cufunction), ::typeof(NeuralPDE.gpu_custom_broadcast!), ::Type{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [3] cufunction
    @ ~/.julia/packages/CUDA/75aiI/src/compiler/execution.jl:361 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{always_inline::Bool, maxthreads::Nothing}, ::typeof(cufunction), ::typeof(NeuralPDE.gpu_custom_broadcast!), ::Type{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [5] #_#4
    @ ~/.julia/packages/CUDA/75aiI/src/compiler/execution.jl:112 [inlined]
  [6] _pullback(::Zygote.Context{…}, ::CUDA.CUDAKernels.var"##_#4", ::Int64, ::Int64, ::KernelAbstractions.Kernel{…}, ::ODEFunction{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] _apply
    @ ./boot.jl:838 [inlined]
  [8] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [10] Kernel
    @ ~/.julia/packages/CUDA/75aiI/src/CUDAKernels.jl:89 [inlined]
 [11] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::KernelAbstractions.Kernel{…}, ::ODEFunction{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] rhs
    @ ~/NeuralPDE.jl/src/ode_solve.jl:236 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.rhs), ::LuxCUDADevice{…}, ::ODEFunction{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] inner_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:225 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.inner_loss), ::NeuralPDE.ODEPhi{…}, ::ODEFunction{…}, ::Bool, ::Vector{…}, ::ComponentVector{…}, ::SVector{…}, ::Bool)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [16] loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:315 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#loss#200"{…}, ::ComponentVector{…}, ::NeuralPDE.ODEPhi{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] total_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:425 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#total_loss#213"{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] _apply
    @ ./boot.jl:838 [inlined]
 [21] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [22] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [23] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/sakPO/src/scimlfunctions.jl:3762 [inlined]
 [24] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [25] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [26] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [27] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [28] #37
    @ ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:90 [inlined]
 [29] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#37#55"{OptimizationFunction{…}, OptimizationBase.ReInitCache{…}}, args::ComponentVector{Float32, CuArray{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [31] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [32] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [33] #39
    @ ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:93 [inlined]
 [34] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::ComponentVector{Float32, CuArray{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [35] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [36] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [37] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [38] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentVector{Float32, CuArray{…}, Tuple{…}}, ::ComponentVector{Float32, CuArray{…}, Tuple{…}})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/32Mb0/ext/OptimizationZygoteExt.jl:93
 [39] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [40] macro expansion
    @ ~/.julia/packages/Optimization/jWtfU/src/utils.jl:32 [inlined]
 [41] __solve(cache::OptimizationCache{OptimizationFunction{…}, OptimizationBase.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#210#214"{…}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [42] solve!(cache::OptimizationCache{OptimizationFunction{…}, OptimizationBase.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#210#214"{…}})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:188
 [43] solve(::OptimizationProblem{true, OptimizationFunction{…}, ComponentVector{…}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, @Kwargs{}}, ::Adam; kwargs::@Kwargs{callback::NeuralPDE.var"#210#214"{…}, maxiters::Int64})
    @ SciMLBase ~/.julia/packages/SciMLBase/sakPO/src/solve.jl:96
 [44] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Float64, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:466
 [45] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:358 [inlined]
 [46] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:612
 [47] solve_call
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:569 [inlined]
 [48] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1080 [inlined]
 [49] solve_up
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
 [50] #solve#51
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003 [inlined]
Some type information was truncated. Use `show(err)` to see complete types.

Calling the broadcast function directly (forward pass only, outside of solve) works:

# This works: plain forward evaluation of the kernel wrapper
out = CuArray(rand(2, 100))  # network output, one column per time point
p_ = CuArray(p)
t = CuArray(rand(100))
du = similar(out)
NeuralPDE.gpu_broadcast(prob_oop.f, du, out, p_, t; workgroupsize = 64, ndrange = 100)
du
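Taking a gradient through the same call, however, should reproduce the failure in isolation from solve, since the stack trace shows Zygote tracing into cufunction while it compiles the kernel. An untested sketch of that check, reusing the arrays from above:

using Zygote

# Differentiating through the launch should hit the same
# "try/catch is not supported" error as the full solve.
Zygote.gradient(out) do o
    du = similar(o)
    NeuralPDE.gpu_broadcast(prob_oop.f, du, o, p_, t; workgroupsize = 64, ndrange = 100)
    sum(du)
end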

@ChrisRackauckas, I am not sure how to interpret this 😅. It says try/catch is not supported, but there is no try/catch in my code. Am I missing anything?
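One guess from the trace: the try/catch may not be in my code at all, but inside CUDA.jl's cufunction compilation path, which Zygote tries to differentiate through (frames [1]-[4]). If so, one possible direction would be to hide the launch from Zygote behind an explicit rule. Below is a hedged, untested sketch; the pullback is only a placeholder, and a correct rule would need the real adjoint of f (e.g. computed by a second kernel):

using ChainRulesCore

# Untested sketch: give the launch an explicit rrule so Zygote never
# traces into cufunction. The pullback below is a placeholder only;
# a correct rule would propagate the cotangent back through f.
function ChainRulesCore.rrule(::typeof(NeuralPDE.gpu_broadcast),
        f, du, out, p, t; kwargs...)
    y = NeuralPDE.gpu_broadcast(f, du, out, p, t; kwargs...)
    function gpu_broadcast_pullback(ȳ)
        return (NoTangent(), NoTangent(), NoTangent(),
            ZeroTangent(), ZeroTangent(), ZeroTangent())  # placeholder
    end
    return y, gpu_broadcast_pullback
end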

sathvikbhagavan · Jun 18 '24 06:06