Optimisers.jl
Wrong model update for BatchNorm with some specific syntax
Package Version
v0.2.13
Julia Version
1.8.2
OS / Environment
Windows
Describe the bug
If opts, m = Optimisers.update!(opts, m, grads) is used within a loop of more than one iteration, inside a function that does not return the model m when complete, the model is improperly updated.
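For reference, these are the two call patterns compared below. As I read the Optimisers.update! docstring, update! may mutate arrays inside the old model and state, but the returned pair is what should be used afterwards:

# Pattern A (fit_1! below): keep the state and model returned by update!
opts, m = Optimisers.update!(opts, m, grads)
# Pattern B (fit_2! below): discard the return value, relying on
# update! mutating the model's arrays in place
Optimisers.update!(opts, m, grads)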
Steps to Reproduce
Define a minimal custom model that contains a BatchNorm, along with toy data and an MSE loss:
using Optimisers
using Flux
using Flux: @functor
using Statistics
using Random: seed!
struct MyModel{N}
    bn::N
end
@functor MyModel
Flux.trainable(m::MyModel) = (bn = m.bn,)
function (m::MyModel)(x)
    fp = m.bn(x)
    return dropdims(mean(fp, dims = 1), dims = 1)
end
seed!(123)
x1, x2 = rand(Float32, 3, 5), rand(Float32, 3, 5)
y1, y2 = rand(Float32, 5), rand(Float32, 5)
loss(m, x, y) = mean((vec(m(x)) .- y) .^ 2)
Two variations of a training loop (with and without the opts, m assignment):
function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        # rebind the local opts and m to the returned state and model
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return nothing
end
function fit_2!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        # discard the returned pair, relying on in-place mutation
        Optimisers.update!(opts, m, grads)
    end
    return nothing
end
m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)
m2 = MyModel(BatchNorm(3))
opt2 = Optimisers.Adam()
opts2 = Optimisers.setup(opt2, m2)
Expected Results
Each loop would be expected to result in identical models. However, that is not the case:
# first loop
julia> loss(m1, x1, y1)
0.07888316f0
julia> fit_1!(m1, loss, opts1)
julia> loss(m1, x1, y1)
0.08968687f0
# second loop
julia> loss(m2, x1, y1)
0.07888316f0
julia> fit_2!(m2, loss, opts2)
julia> loss(m2, x1, y1)
0.10831509f0
Note that if the loop had only a single iteration (for i = 1:1), the models would be identical.
Also, if the model is returned after the loop and assigned, the two loops behave the same:
function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return m  # this time the rebound model is handed back to the caller
end
m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)
julia> loss(m1, x1, y1)
0.07888316f0
julia> m1 = fit_1!(m1, loss, opts1);
julia> loss(m1, x1, y1)
0.10831509f0
This 0.10831509f0 matches the value obtained with fit_2!.
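One way to localize the divergence might be to compare the BatchNorm's trainable parameters and tracked statistics directly after each loop. A sketch (the field names β, γ, μ, σ² are Flux's BatchNorm internals at the time of writing, so treat this as illustrative):

# after running the original fit_1!(m1, ...) and fit_2!(m2, ...)
m1.bn.β ≈ m2.bn.β    # trainable shift
m1.bn.γ ≈ m2.bn.γ    # trainable scale
m1.bn.μ ≈ m2.bn.μ    # tracked mean (non-trainable)
m1.bn.σ² ≈ m2.bn.σ²  # tracked variance (non-trainable)

If only the tracked μ / σ² disagree, that would point at the running-statistics update being lost somewhere.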
Observed Results
none
Relevant log output
none
Not sure whether #2122 might be related (considering BatchNorm is also the layer causing the issue there).
Haven't had a chance to look. My guess is that somehow the mutation of the BatchNorm struct is lost by Optimisers. Would be interesting to know if https://github.com/FluxML/Flux.jl/pull/2127 changes this, as it mutates the tracking arrays directly instead.
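If it helps, plain Julia rebinding semantics already explain how a replaced (rather than mutated) piece of the model could get lost. A minimal sketch, nothing Optimisers-specific (Box and touch! are made-up names):

mutable struct Box
    x::Int
end
function touch!(b)
    b = Box(b.x + 1)  # rebinds only the local variable b
    return nothing    # the new Box is discarded, as m is in fit_1!
end
b = Box(0)
touch!(b)
b.x  # still 0: the caller's binding never saw the new Box

So if update! (or the BatchNorm forward pass) reconstructs part of the model tree instead of mutating it in place, a loop like fit_1! that does not return m silently drops those reconstructed pieces.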
To add a data point: with the latest versions of Flux and Optimisers, this bug seems to have been fixed.
We should add a test to make sure we don't have regressions.
I see the related test was added in Flux. Does Optimisers need a test for update!?
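Something along these lines might work as a regression test on the Optimisers side. This is only a sketch: it assumes Flux as a test dependency, and the Chain-based model plus the fit! helper are illustrative stand-ins for the reproduction above:

using Test, Flux, Optimisers, Statistics
using Random: seed!

loss(m, x, y) = mean((vec(m(x)) .- y) .^ 2)

# run two update! iterations, either keeping or discarding the returned pair
function fit!(m, opts; rebind::Bool)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        if rebind
            opts, m = Optimisers.update!(opts, m, grads)
        else
            Optimisers.update!(opts, m, grads)
        end
    end
    return nothing
end

@testset "update! keeps BatchNorm state in sync with the caller's model" begin
    m1 = Chain(BatchNorm(3), x -> vec(mean(x, dims = 1)))
    m2 = Chain(BatchNorm(3), x -> vec(mean(x, dims = 1)))
    fit!(m1, Optimisers.setup(Optimisers.Adam(), m1); rebind = true)
    fit!(m2, Optimisers.setup(Optimisers.Adam(), m2); rebind = false)
    seed!(123)
    x, y = rand(Float32, 3, 5), rand(Float32, 5)
    # both loop styles should leave the caller's model in the same state
    @test loss(m1, x, y) ≈ loss(m2, x, y)
end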