Batching at every gradient step
I am new to neural differential equations and have been going through some tutorials to understand them better. I noticed that the Diffrax tutorial in Python uses a batching scheme for training, where every gradient step uses 32 trajectories. This runs surprisingly fast, but when I tried to implement the same thing in Julia, either via Optimization (setting maxiters=1 in solve) or via Lux.Training directly, it took forever.
Am I totally misunderstanding something from the tutorial, or is this just not something that the Julia packages built on DiffEqFlux are optimised for? Thank you in advance!
Can you share your code?
Create data
using DiffEqFlux, OrdinaryDiffEq, Lux, ComponentArrays, Optimisers, Distributions, Random, Zygote
# Times
n = 100 # sample size: i ∈ 1, ..., n
tspan = (0.0f0, 10f0)
t = range(tspan[1], tspan[2], length=n)
# Initial conditions
Random.seed!(0)
m = 2 # data dimensionality j ∈ 1, ..., m
p = 256 # number of sequences k ∈ 1, ..., p
Y0 = Float32.(rand(Uniform(-0.6, 1), (m, p))) # initial conditions
# Integrate true ODE
function truth!(du, u, p, t)
    z = u ./ (1 .+ u)
    du[1], du[2] = z[2], -z[1]
end
get_data(y0) = begin
    ode = ODEProblem(truth!, y0, tspan)
    y = solve(ode, Tsit5(), saveat=t)
    Y = Array(y)
end
Y = cat(get_data.(eachcol(Y0))..., dims=3) # m × n × p
Create NODE
# Initial neural network
NN = Chain(Dense(m, 64, softplus), Dense(64, 64, softplus), Dense(64, m)) # (Lux) neural network NN: x ↦ ẋ
θ0, 𝒮 = Lux.setup(Xoshiro(0), NN) # initialise parameters θ
# Instantiate NeuralODE model
function neural_ode(NN, t)
    node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)
end
Train
# Loss function
function L(NN, θ, 𝒮, (t, x0, y)) # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple() # Outputs: loss, state, stats
end
# Initialise training state
opt = AdamW(5e-3)
train_state = Lux.Training.TrainState(NN, ComponentArray(θ0), 𝒮, opt)
# Train one step
traj_size = 10
idx_traj = 1:traj_size
batch_size = 32
idx_batch = randperm(p)[1:batch_size]
∇θ, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), L, (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch]), train_state
)
This runs (not sure if it is correct), but even with only the first 10 time steps it takes ~5 seconds per step with a batch of 32. I don't have the Optimization-based code any more, but I remember it taking a long time because I rebuilt the optimisation problem inside the for loop (one iteration per gradient step, each with a new random batch of trajectories).
x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)
You are broadcasting over each trajectory in the batch, which is expected to be slow. Instead, you can pass the whole batch into node, like node(x0, theta, st), and drop the cat operation.
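The same applies to the outer loop: there is no need to rebuild an optimisation problem per gradient step. Just keep reusing the same train_state and draw a new random batch of trajectories each iteration. A rough sketch reusing your variable names (n_steps is only a placeholder):

n_steps = 1000 # placeholder: total number of gradient steps
function train_loop(train_state, n_steps)
    for _ in 1:n_steps
        idx_batch = randperm(p)[1:batch_size] # fresh random batch of trajectories each step
        data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
        # the same train_state is threaded through every step; nothing gets rebuilt
        ∇θ, loss, stats, train_state = Training.single_train_step!(AutoZygote(), L, data, train_state)
    end
    return train_state
end
train_state = train_loop(train_state, n_steps)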
I changed the loss function to:
# Loss function
function L(NN, θ, 𝒮, (t, x0, y)) # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    # Solve the whole batch at once; the solution array is m × batch × time,
    # so permute it to m × time × batch to match y
    x = permutedims(Array(node(x0, θ, 𝒮)[1]), (1, 3, 2))
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple() # Outputs: loss, state, stats
end
and it decreased the training time to ~1 second per step, but that's still much slower than the Diffrax example (<0.01 seconds per step)... Do you reckon it would be better to use Optimization with the updated loss function (even though that still means building and solving the optimisation problem at each step)? I could put together a quick test example.
node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)
Diffrax is using much higher (i.e. looser) tolerances than the abstol=1e-9, reltol=1e-9 you have here, so each of its solves is much cheaper.
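For a fairer speed comparison you could relax them, e.g. something like the sketch below (the exact values are up to you; if I remember right the Diffrax example sits around rtol=1e-3, atol=1e-6):

# Looser tolerances, roughly in the range the Diffrax example uses (values indicative only)
function neural_ode(NN, t)
    NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1f-6, reltol=1f-3)
end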