
GPU out of memory

Open marcofrancis opened this issue 1 year ago • 8 comments

Hi, I translated my code to the new version of NeuralPDE, and I saw that the GPU example now uses Lux instead of Flux. Unfortunately, while the previous version of the code ran fine on my GPU, I now get an out-of-GPU-memory error. Any idea why this might happen?

Here's the code I use:

using NeuralPDE, Lux, CUDA, ModelingToolkit, IntegralsCubature, QuasiMonteCarlo, Random
using Optimization, OptimizationOptimJL, OptimizationOptimisers

using Plots, JLD2
import ModelingToolkit: Interval

τ_min = 0.0
τ_max = 1.0
τ_span = (τ_min, τ_max)

ω_min = -2.0
ω_max = 2.0

s_min = 0.0
s_max = 1.0

dω = 0.05
ds = 0.05

dt_NN = 0.01

μY = 0.03
σ = 0.03
λ = 0.1
ωb = μY/λ

ϕ = 0.1
sb = 0.5
v¹ = [1 0]
v² = [0 -1]
v = v¹-v²

γ = 1.5
δ = 0.01
h(ω,s) = exp.(-(γ-1)*ω).*s


@parameters τ ω s
@variables q(..) qₛ(..)

Dω = Differential(ω)
Dωω = Differential(ω)^2

Ds = Differential(s)
Dss = Differential(s)^2

Dωs = Differential(ω)*Differential(s)
Dτ = Differential(τ)

μω = λ*(ω-ωb)
σω² = σ^2

μs = ϕ*(s-sb)
σs² = s^2*(1-s)^2*sum(v.^2)

σω_σs = σ*s*(1-s)*sum([1 0]'*v)

# PDE
eq_NN = (1/100*Dτ(q(τ,ω,s)) ~ 0.5*σω²*Dωω(q(τ,ω,s)) + 0.5*σs²*Dss(q(τ,ω,s)) + σω_σs*Dωs(q(τ,ω,s))
           - μω*Dω(q(τ,ω,s)) -μs*Ds(q(τ,ω,s)) - δ*q(τ,ω,s) + h(ω,s))

# Boundary Conditions
bcs_NN = [q(τ_min,ω,s) ~ h(ω,s),
        Dω(q(τ,ω_max,s)) ~ 0,
        Dω(q(τ,ω_min,s)) ~ 0,
        Ds(q(τ,ω,s_min)) ~ 0,
        Ds(q(τ,ω,s_max)) ~ 0]
    

# Space and time domains
domains = [τ ∈ Interval(τ_min,τ_max),
           ω ∈ Interval(ω_min,ω_max),
           s ∈ Interval(s_min,s_max)]


τs,ωs,ss = [infimum(d.domain):dω:supremum(d.domain) for d in domains]


@named pde_system_NN = PDESystem(eq_NN,bcs_NN,domains,[τ,ω,s],[q(τ, ω, s)])

# NN parameters
dim = 3
hls = dim + 50

chain = Chain(Dense(dim,hls,Lux.σ),
            Dense(hls,hls,Lux.σ),
            Dense(hls,hls,Lux.σ),
            Dense(hls,1)) 
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> Lux.ComponentArray |> gpu .|> Float64
    
strategy = GridTraining([dt_NN,dω,ds])

discretization = PhysicsInformedNN(chain,
                                   strategy, init_params = ps)
                                 
prob = discretize(pde_system_NN,discretization)

callback = function (p,l)
        println("Current loss is: $l")
        return false
end

 
res = Optimization.solve(prob,OptimizationOptimisers.Adam(0.01);callback = callback,maxiters=500)
prob = remake(prob,u0=res.u)

The error I get is the following:

ERROR: Out of GPU memory trying to allocate 68.158 MiB
Effective GPU memory usage: 100.00% (8.000 GiB/8.000 GiB)
Memory pool usage: 7.240 GiB (7.406 GiB reserved)

I've attached the full stack trace: Stack_trace_OOM.txt

marcofrancis · Aug 31 '22 19:08

The NN weights are in Float64; use Float32, which needs less memory.
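
A minimal sketch of that change, reusing the parameter setup from the posted code (Lux initializes parameters in Float32 by default, so dropping the .|> Float64 conversion is enough):

# Same setup as in the posted code, but without converting the parameters to Float64
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> Lux.ComponentArray |> gpu   # parameters stay Float32 on the GPU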

KirillZubov · Sep 01 '22 09:09

Or reduce the number of weights by making hls smaller.
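
A sketch of that option; the width 16 here is only a placeholder, not a value suggested in the thread:

# Hypothetical smaller hidden-layer width (the posted code uses hls = dim + 50)
hls = dim + 16
chain = Chain(Dense(dim, hls, Lux.σ),
              Dense(hls, hls, Lux.σ),
              Dense(hls, hls, Lux.σ),
              Dense(hls, 1))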

KirillZubov · Sep 01 '22 09:09

OK, Float32 did the trick. I'm a bit confused though: the net is quite small (<10K nodes), so why is it taking the whole 8 GB of my GPU? What am I getting wrong? Is it loading the whole grid into memory? Doesn't it use mini-batches to train?

marcofrancis · Sep 06 '22 16:09

It depends on when the GC fires. If the GC doesn't fire before GPU memory fills up, then it will error.

@KirillZubov I think we should put an explicit GC call at the bottom of the loss function

ChrisRackauckas · Sep 06 '22 17:09

Any workaround I could use for the time being?

marcofrancis · Sep 07 '22 06:09

Try Flux; it's probably an issue with the GC in Lux.

KirillZubov · Sep 07 '22 11:09

@marcofrancis It uses only one batch by default now. Each batch must contain homogeneous data, i.e. a complete mesh; otherwise it becomes a decomposition into subtasks, see https://neuralpde.sciml.ai/dev/tutorials/neural_adapter/#Domain-decomposition. If a single batch does not contain enough input data, training converges poorly, so many batches over a sparse grid is a bad solution.

A sufficiently dense mesh with the maximum batch size should give better convergence, but it requires as much memory as the GPU has.
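
For a sense of scale, a rough count of the collocation points implied by GridTraining([dt_NN,dω,ds]) with the step sizes and domain bounds from the posted code; the PDE residual, its derivatives, and the intermediate arrays of the loss are evaluated over this entire set in a single batch:

nτ = length(τ_min:dt_NN:τ_max)   # 101 points in τ (step 0.01)
nω = length(ω_min:dω:ω_max)      # 81 points in ω (step 0.05)
ns = length(s_min:ds:s_max)      # 21 points in s (step 0.05)
nτ * nω * ns                     # 171_801 grid points handled at once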

KirillZubov · Sep 07 '22 12:09

> Try Flux; it's probably an issue with the GC in Lux.

They are the same GC.

> Any workaround I could use for the time being?

Add a GC.gc() call to a callback in the optimization. We should just add it to the bottom of the loss function.
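
A minimal sketch of that workaround, adding the GC call to the callback already defined in the posted code:

callback = function (p, l)
    println("Current loss is: $l")
    GC.gc()   # force a garbage-collection pass each iteration so freed GPU buffers are released
    return false
end

res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.01); callback = callback, maxiters = 500)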

ChrisRackauckas · Sep 07 '22 23:09