DiffEqFlux.jl

Batching at every gradient step

Open mjyshin opened this issue 9 months ago • 5 comments

I am new to neural differential equations and have been going through some tutorials to better understand them. I noticed that the Diffrax tutorial in Python uses a batching scheme for training, where every gradient step appears to use 32 trajectories. That runs surprisingly fast, but when I tried to implement the same scheme in Julia, either via Optimization (setting maxiters=1 in solve) or via Lux.Training directly, it took forever.
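
For concreteness, the Optimization-based pattern I mean looked roughly like this (batch_loss, θ_init, first_batch, next_batch and n_steps are stand-ins here, not my actual code):

using Optimization, OptimizationOptimisers, Optimisers, Zygote    # sketch only

optf = OptimizationFunction((θ, batch) -> batch_loss(θ, batch), Optimization.AutoZygote())    # batch_loss: placeholder loss on one minibatch
optprob = OptimizationProblem(optf, θ_init, first_batch)    # θ_init, first_batch: placeholders
for _ in 1:n_steps
    global optprob
    sol = solve(optprob, Optimisers.AdamW(5e-3); maxiters=1)    # one gradient step on the current batch
    optprob = remake(optprob; u0=sol.u, p=next_batch())         # rebuild the problem with a fresh random batch
end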

Am I totally misunderstanding something from the tutorial, or is this not a feature that is optimised for in any of the Julia packages that use DiffEqFlux? Thank you in advance!

mjyshin · Feb 18 '25 11:02

Can you share your code?

ChrisRackauckas · Feb 18 '25 13:02

Create data

# Packages assumed by these snippets
using DiffEqFlux, OrdinaryDiffEq, Lux, ComponentArrays, Optimisers, ADTypes, Zygote, Distributions, Random

# Times
n = 100    # sample size: i ∈ 1, ..., n
tspan = (0.0f0, 10f0)
t = range(tspan[1], tspan[2], length=n)

# Initial conditions
Random.seed!(0)
m = 2    # data dimensionality j ∈ 1, ..., m
p = 256    # number of sequences k ∈ 1, ..., p
Y0 = Float32.(rand(Uniform(-0.6, 1), (m, p)))    # initial conditions

# Integrate true ODE
function truth!(du, u, p, t)
    z = u ./ (1 .+ u)
    du[1], du[2] = z[2], -z[1]
end
get_data(y0) = begin
    ode = ODEProblem(truth!, y0, tspan)
    y = solve(ode, Tsit5(), saveat=t)
    Y = Array(y)
end
Y = cat(get_data.(eachcol(Y0))..., dims=3)    # m × n × p

Create NODE

# Initial neural network
NN = Chain(Dense(m, 64, softplus), Dense(64, 64, softplus), Dense(64, m))    # (Lux) neural network NN: x ↦ ẋ
θ0, 𝒮 = Lux.setup(Xoshiro(0), NN)    # initialise parameters θ

# Instantiate NeuralODE model
function neural_ode(NN, t)
    node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)
end

Train

# Loss function
function L(NN, θ, 𝒮, (t, x0, y))    # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple()    # Outputs: loss, state, stats
end

# Initialise training state
opt = AdamW(5e-3)
train_state = Lux.Training.TrainState(NN, ComponentArray(θ0), 𝒮, opt)

# Train one step
traj_size = 10
idx_traj = 1:traj_size
batch_size = 32
idx_batch = randperm(p)[1:batch_size]
∇θ, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), L, (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch]), train_state
)

This runs (I'm not sure whether it is correct), but even with only the first 10 time steps it takes ~5 seconds per step with a batch of 32. I no longer have the code using Optimization, but I remember it taking a long time because I rebuilt the optimisation problem inside the for loop (one solve per gradient step, each with a new random batch of trajectories).
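
The full loop I have in mind is basically just this, reusing the same train_state every step (n_steps is a made-up number):

# One gradient step per random minibatch, reusing train_state across steps
n_steps = 100    # made-up number of gradient steps
for _ in 1:n_steps
    global train_state
    idx_batch = randperm(p)[1:batch_size]
    data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
    _, loss, _, train_state = Training.single_train_step!(AutoZygote(), L, data, train_state)
end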

mjyshin · Feb 18 '25 14:02

x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)

You are broadcasting over each element of the batch, which is expected to be slow. Instead, you can pass the whole batch into node, i.e. node(x0, theta, st), and drop the cat operation.

avik-pal · Feb 18 '25 15:02

x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)

You are broadcasting over each element of the batch, which is expected to be slow. Instead, you can pass the whole batch into node, i.e. node(x0, theta, st), and drop the cat operation.

I changed the loss function to:

# Loss function
function L(NN, θ, 𝒮, (t, x0, y))    # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    x = permutedims(Array(node(x0, θ, 𝒮)[1]), (1, 3, 2))
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple()    # Outputs: loss, state, stats
end

and it decreased the time per training step to ~1 second, but that's still much slower than the Diffrax example (<0.01 seconds)... Do you reckon it would be better to use Optimization with an updated loss function (even though that would still build and solve the optimisation problem at each step)? I could make a quick test example.
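
For what it's worth, a quick way to check the cost of a warm step (excluding the compilation that happens on the first call):

# Run one step to compile, then time a steady-state step
data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
Training.single_train_step!(AutoZygote(), L, data, train_state)
@time Training.single_train_step!(AutoZygote(), L, data, train_state)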

mjyshin · Feb 18 '25 16:02

    node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)

Diffrax is using much higher (i.e. looser) tolerances, so each adaptive solve takes far fewer steps than it does at abstol = reltol = 1e-9.
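
For example, something in this range is a fairer comparison (illustrative values, not necessarily the exact ones from the Diffrax tutorial):

# Looser tolerances let the adaptive solver take far fewer steps per solve
function neural_ode(NN, t)
    NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-6, reltol=1e-3)
end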

avik-pal · Feb 18 '25 18:02