NeuralPDE.jl

Cannot use BFGS with GPU

Open marcofrancis opened this issue 1 year ago • 10 comments

Hi, following the new tutorial I moved my GPU net from Flux to Lux. Unfortunately now I cannot use BFGS for the optimization. The code I use is the following:

using NeuralPDE, Lux, CUDA, ModelingToolkit, IntegralsCubature, QuasiMonteCarlo, Random
using Optimization, OptimizationOptimJL, OptimizationOptimisers
import ModelingToolkit: Interval

using Plots, JLD2

τ_min= 0.
τ_max = 1.0
τ_span = (τ_min,τ_max)

ω_min = -2.
ω_max = 2.

dω = 0.05
dt_NN = 0.01

μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ

γ = 1.5
k = 0.1
h(ω) = exp.(-(γ-1)*ω)


@parameters τ ω
@variables f(..)

Dω = Differential(ω)
Dωω = Differential(ω)^2
Dτ = Differential(τ)

μω = λ*(ω-ωb)
σω = σ

# PDE
eq  = 1/100*Dτ(f(τ,ω)) ~ 0.5*σω^2*Dωω(f(τ,ω)) - μω*Dω(f(τ,ω)) - k*f(τ,ω) + h(ω) 

# Initial and boundary conditions
bcs = [f(τ_min,ω) ~ h(ω),
       Dω(f(τ,ω_max)) ~ 0,
       Dω(f(τ,ω_min)) ~ 0]

# Space and time domains
domains = [τ ∈ Interval(τ_min,τ_max),
           ω ∈ Interval(ω_min,ω_max)]

τs,ωs = [infimum(d.domain):dω:supremum(d.domain) for d in domains]


@named pde_system = PDESystem(eq,bcs,domains,[τ,ω],[f(τ, ω)])

# NN parameters
dim =2
hls = dim+50

# Neural network

chain = Chain(Dense(dim,hls,Lux.σ),
              Dense(hls,hls,Lux.σ),
              Dense(hls,hls,Lux.σ),
              Dense(hls,1))
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> Lux.ComponentArray |> gpu .|> Float64

strategy = GridTraining([dω,dt_NN])

discretization = PhysicsInformedNN(chain,
                                   strategy, init_params = ps)

prob = discretize(pde_system,discretization)

callback = function (p,l)
    println("Current loss is: $l")
    return false
end

bfgs = OptimizationOptimJL.BFGS()
res = Optimization.solve(prob,bfgs ;callback = callback,maxiters=100)
prob = remake(prob,u0=res.u)

And I get this error:

ERROR: MethodError: no method matching unsafe_convert(::Type{Ptr{Float64}}, ::CuPtr{Float64})
Closest candidates are:
  unsafe_convert(::Type{RefOrCuRef{T}}, ::Union{CuPtr{T}, CUDA.CuRefArray{T, A} where A<:(AbstractArray{T})}) where T at C:\Users\MARCO\.julia\packages\CUDA\DfvRa\src\pointer.jl:264
  unsafe_convert(::Type{RefOrCuRef{T}}, ::Any) where T at C:\Users\MARCO\.julia\packages\CUDA\DfvRa\src\pointer.jl:260
  unsafe_convert(::Type{<:Union{CuArrayPtr, CuPtr, Ptr}}, ::CUDA.Mem.AbstractBuffer) at C:\Users\MARCO\.julia\packages\CUDA\DfvRa\lib\cudadrv\memory.jl:33

I attach the full trace: Stack_trace_BFGS.txt

marcofrancis, Aug 31 '22 20:08

Does it work with Adam?

KirillZubov, Sep 01 '22 09:09

It works with Adam, and it works with Flux. The problem is in NeuralPDE: somewhere in its code there is an unsafe conversion.

Here is an MWE showing that everything works when called directly, so the problem is definitely in NeuralPDE:

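using Flux, CUDA, Random
using Optimization, OptimizationOptimJL

rng = Random.default_rng()

dim = 2
hls = 5
# Flux model moved to the GPU, so the flattened parameters are a CuArray
model = Flux.Chain(Flux.Dense(dim, hls, Flux.σ),
                   Flux.Dense(hls, hls, Flux.σ),
                   Flux.Dense(hls, hls, Flux.σ),
                   Flux.Dense(hls, 1)) |> Flux.gpu
initθ, re = Flux.destructure(model)

x = rand(rng, Float32, 2, 2) |> Flux.gpu
y = rand(rng, Float32, 1, 2) |> Flux.gpu

# Loss in the (θ, p) form expected by OptimizationFunction
loss(θ, p) = sum(abs2, re(θ)(x) - y)

loss(initθ, nothing)

callback = function (p, l)
    println("Current loss is: $l")
    return false
end

f = OptimizationFunction(loss,
                         Optimization.AutoZygote())
prob = Optimization.OptimizationProblem(f, initθ)

bfgs = OptimizationOptimJL.BFGS()
res = Optimization.solve(prob, bfgs; callback = callback, maxiters = 10)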

KirillZubov avatar Sep 01 '22 13:09 KirillZubov

Ok, so it's not only me. Any disadvantage in using Flux over Lux for GPU training?

marcofrancis, Sep 06 '22 16:09

You need to overload mul!. See how I did it here

YichengDWu, Sep 06 '22 21:09
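The mul! suggestion can be illustrated with a minimal, CPU-only Julia sketch. It assumes the failure comes from an outer-product update such as mul!(C, a, b') (the kind used in BFGS inverse-Hessian updates) being handed wrapper types, a GPU ComponentVector and its adjoint, so that dispatch does not reach the method for the underlying array. The overload unwraps both operands and calls mul! on the plain storage, which mirrors "taking the adjoint of the data instead". Wrapped and FORWARDED are hypothetical names standing in for the real ComponentArrays/CUDA types; this is not the code linked in the comment.

```julia
using LinearAlgebra

# Hypothetical wrapper standing in for a GPU-backed ComponentVector:
# it stores a plain vector in `data` and implements only the minimal array interface.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(w::Wrapped) = size(w.data)
Base.IndexStyle(::Type{<:Wrapped}) = IndexLinear()
Base.getindex(w::Wrapped, i::Int) = w.data[i]

# Counts how often the forwarding method below is used.
const FORWARDED = Ref(0)

# C = α*a*b' + β*C. Instead of letting a method chosen for the wrapper type
# handle this, unwrap both operands and call mul! on the underlying arrays,
# so the method specialised for that storage (BLAS here) is selected.
function LinearAlgebra.mul!(C::AbstractMatrix, a::Wrapped, b::Adjoint{<:Any,<:Wrapped},
                            α::Number, β::Number)
    FORWARDED[] += 1
    return mul!(C, a.data, parent(b).data', α, β)
end

a = Wrapped([1.0, 2.0])
b = Wrapped([3.0, 4.0])
C = zeros(2, 2)

# The 3-argument mul!(C, a, b') forwards to this 5-argument form.
mul!(C, a, b', true, false)
println(C)
```

On a GPU the same pattern would forward to the CuArray (CUBLAS) method; which exact mul! signatures Optim's BFGS calls is an assumption not checked here.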

Is there an issue or PR for upstreaming that to ComponentArrays?

ChrisRackauckas, Sep 07 '22 23:09

No. Mainly because I didn't know (and still don't know) how to take the adjoint of a GPUComponentArray, so I just took the adjoint of the data instead :sweat_smile:. I didn't dig into it; there should be a proper way to do this.

YichengDWu, Sep 08 '22 00:09

It looks like it's a matter of displaying the adjoint:

julia> ps
ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}(a = Float32[0.8377614 0.01191762 0.8278436 0.3350626; 0.7597569 0.4549542 0.018469304 0.0072024767; 0.09418385 0.1333899 0.42138302 0.89377856], b = Float32[0.059708346 0.16068034 0.72350216 0.41418508; 0.7101168 0.45254472 0.25748685 0.22498575; 0.07498102 0.50259334 0.71517354 0.5034971])

julia> ps_adj = ps';

julia> ps_adj
1×24 adjoint(::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}) with eltype Float32 with indices Base.OneTo(1)×1:1:24:
Error showing value of type Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}}:
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/lojQM/src/GPUArraysCore.jl:87
  [3] getindex
    @ ~/.julia/packages/GPUArrays/fqD8z/src/host/indexing.jl:9 [inlined]
  [4] getindex
    @ ~/.julia/packages/ComponentArrays/EjZNJ/src/array_interface.jl:96 [inlined]
  [5] getindex
    @ /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:179 [inlined]
  [6] _getindex
    @ ./abstractarray.jl:1274 [inlined]
  [7] getindex
    @ ./abstractarray.jl:1241 [inlined]
  [8] isassigned(::Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}}, ::Int64, ::Int64)
    @ Base ./abstractarray.jl:565
  [9] alignment(io::IOContext{Base.TTY}, X::AbstractVecOrMat, rows::Vector{Int64}, cols::Vector{Int64}, cols_if_complete::Int64, cols_otherwise::Int64, sep::Int64, ncols::Int64)
    @ Base ./arrayshow.jl:68
 [10] _print_matrix(io::IOContext{Base.TTY}, X::AbstractVecOrMat, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64, rowsA::UnitRange{Int64}, colsA::UnitRange{Int64})
    @ Base ./arrayshow.jl:207
 [11] print_matrix(io::IOContext{Base.TTY}, X::Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}}, pre::String, sep::String, post::String, hdots::String, vdots::String, ddots::String, hmod::Int64, vmod::Int64) (repeats 2 times)
    @ Base ./arrayshow.jl:171
 [12] print_array
    @ ./arrayshow.jl:358 [inlined]
 [13] show(io::IOContext{Base.TTY}, #unused#::MIME{Symbol("text/plain")}, X::Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}})
    @ Base ./arrayshow.jl:399
 [14] display(d::REPL.REPLDisplay{REPL.LineEditREPL}, mime::MIME{Symbol("text/plain")}, x::Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = ViewAxis(1:12, ShapedAxis((3, 4), NamedTuple())), b = ViewAxis(13:24, ShapedAxis((3, 4), NamedTuple())))}}}})
    @ OhMyREPL ~/.julia/packages/OhMyREPL/oDZvT/src/output_prompt_overwrite.jl:8
 [15] display(d::REPL.REPLDisplay, x::Any)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:272
 [16] display(x::Any)
    @ Base.Multimedia ./multimedia.jl:328
 [17] #invokelatest#2
    @ ./essentials.jl:729 [inlined]
 [18] invokelatest
    @ ./essentials.jl:726 [inlined]
 [19] print_response(errio::IO, response::Any, show_value::Bool, have_color::Bool, specialdisplay::Union{Nothing, AbstractDisplay})
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:296
 [20] (::REPL.var"#45#46"{REPL.LineEditREPL, Pair{Any, Bool}, Bool, Bool})(io::Any)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:278
 [21] with_repl_linfo(f::Any, repl::REPL.LineEditREPL)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:521
 [22] print_response(repl::REPL.AbstractREPL, response::Any, show_value::Bool, have_color::Bool)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:276
 [23] (::REPL.var"#do_respond#66"{Bool, Bool, REPL.var"#77#87"{REPL.LineEditREPL, REPL.REPLHistoryProvider}, REPL.LineEditREPL, REPL.LineEdit.Prompt})(s::REPL.LineEdit.MIState, buf::Any, ok::Bool)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:857
 [24] #invokelatest#2
    @ ./essentials.jl:729 [inlined]
 [25] invokelatest
    @ ./essentials.jl:726 [inlined]
 [26] run_interface(terminal::REPL.Terminals.TextTerminal, m::REPL.LineEdit.ModalInterface, s::REPL.LineEdit.MIState)
    @ REPL.LineEdit /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/LineEdit.jl:2510
 [27] run_frontend(repl::REPL.LineEditREPL, backend::REPL.REPLBackendRef)
    @ REPL /opt/julias/julia-1.8.0/share/julia/stdlib/v1.8/REPL/src/REPL.jl:1248
 [28] (::REPL.var"#49#54"{REPL.LineEditREPL, REPL.REPLBackendRef})()
    @ REPL ./task.jl:484

YichengDWu, Sep 08 '22 16:09

Many operations fall back to scalar indexing. If Adapt.jl isn't used properly and other overloads don't exist, then Adjoint-wrapped abstract GPU arrays will keep hitting those fallbacks. I assume this might happen in a few cases? But it's at least a start, so I would say it's worth upstreaming, and it can keep improving over time.

ChrisRackauckas, Sep 08 '22 21:09
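The fallback behaviour can be made visible on the CPU with a small Julia sketch. CountingVector and SCALAR_READS are hypothetical: the vector counts every scalar getindex, which is the kind of access GPUArrays rejects with "Scalar indexing is disallowed". Materialising its adjoint through the generic AbstractArray path reads it one element at a time, whereas a forwarding method that unwraps to the underlying storage performs no scalar reads on the wrapper.

```julia
using LinearAlgebra

# Hypothetical stand-in for a GPU-backed array: every scalar read is counted,
# mimicking the reads that GPUArrays forbids.
const SCALAR_READS = Ref(0)

struct CountingVector <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::CountingVector) = size(v.data)
Base.IndexStyle(::Type{CountingVector}) = IndexLinear()
function Base.getindex(v::CountingVector, i::Int)
    SCALAR_READS[] += 1
    return v.data[i]
end

v = CountingVector([1.0, 2.0, 3.0, 4.0])

# No specialised method yet: materialising the adjoint uses the generic
# fallback, which reads the wrapped vector element by element.
SCALAR_READS[] = 0
m1 = Matrix(v')
fallback_reads = SCALAR_READS[]

# Forwarding method: unwrap to the underlying storage and let the
# method for the plain Vector do the work.
Base.Matrix(a::Adjoint{Float64,CountingVector}) = Matrix(parent(a).data')

SCALAR_READS[] = 0
m2 = Matrix(v')
forwarded_reads = SCALAR_READS[]

println("fallback reads: ", fallback_reads, ", forwarded reads: ", forwarded_reads)
```

Each such forwarding method removes one fallback; operations without one still go through scalar indexing, which matches the "start that can keep improving" view above.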

I'll try to do it if I have time. The new semester is taking up most of my time, unfortunately.

YichengDWu, Sep 08 '22 23:09

Would be fixed by https://github.com/jonniedie/ComponentArrays.jl/pull/167

YichengDWu, Oct 17 '22 01:10