
Wrong model update for BatchNorm with some specific syntax

Open jeremiedb opened this issue 2 years ago • 5 comments

Package Version

v0.2.13

Julia Version

1.8.2

OS / Environment

Windows

Describe the bug

If opts, m = Optimisers.update!(opts, m, grads) is used within a loop with more than one iteration, and the function doesn't return the model m when complete, the model is improperly updated.
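
For reference, Optimisers.update! returns both the new optimiser state and the new model, and the returned model is the one that should be used afterwards. A minimal sketch of that pattern (the function name and arguments here are illustrative only):

function fit!(m, loss, opts, data)
    for (x, y) in data
        grads = gradient(model -> loss(model, x, y), m)[1]
        # keep both returned values and rebind m to the returned model
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return m  # the caller must rebind its model to this return value
end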

Steps to Reproduce

Define a minimal custom model that contains a BatchNorm, as well as toy data and an MSE loss:

using Optimisers
using Flux
using Flux: @functor
using Statistics
using Random: seed!

struct MyModel{N}
    bn::N
end
@functor MyModel
Flux.trainable(m::MyModel) = (bn = m.bn,)

function (m::MyModel)(x)
    fp = m.bn(x)
    return dropdims(mean(fp, dims = 1), dims = 1)
end

seed!(123)
x1, x2 = rand(Float32, 3, 5), rand(Float32, 3, 5)
y1, y2 = rand(Float32, 5), rand(Float32, 5)
loss(m, x, y) = mean((vec(m(x)) .- y) .^ 2)

Two variations of a training loop (with and without the opts, m assignment):

function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return nothing
end

function fit_2!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        Optimisers.update!(opts, m, grads)
    end
    return nothing
end

m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)

m2 = MyModel(BatchNorm(3))
opt = Optimisers.Adam()
opts2 = Optimisers.setup(opt, m2)

Expected Results

Each loop would be expected to result in identical models. However, that's not the case:

# first loop
julia> loss(m1, x1, y1)
0.07888316f0
julia> fit_1!(m1, loss, opts1)
julia> loss(m1, x1, y1)
0.08968687f0

# second loop
julia> loss(m2, x1, y1)
0.07888316f0
julia> fit_2!(m2, loss, opts2)
julia> loss(m2, x1, y1)
0.10831509f0

Note that if the loop had only a single iteration (for i = 1:1), the models would then be identical.

Also, if the model is returned after the loop and reassigned, then the two loops behave the same:

function fit_1!(m, loss, opts)
    seed!(123)
    for i = 1:2
        x, y = rand(Float32, 3, 5), rand(Float32, 5)
        grads = gradient(model -> loss(model, x, y), m)[1]
        opts, m = Optimisers.update!(opts, m, grads)
    end
    return m
end

m1 = MyModel(BatchNorm(3))
opt1 = Optimisers.Adam()
opts1 = Optimisers.setup(opt1, m1)

julia> loss(m1, x1, y1)
0.07888316f0
julia> m1 = fit_1!(m1, loss, opts1);
julia> loss(m1, x1, y1)
0.10831509f0

The loss of 0.10831509f0 is the same as the value obtained using fit_2!.

Observed Results

none

Relevant log output

none

jeremiedb commented on Dec 02 '22

Not sure whether #2122 might be related (considering BatchNorm is also the layer causing the issue there).

jeremiedb commented on Dec 02 '22

Haven't had a chance to look. My guess is that somehow the mutation of the BatchNorm struct is lost by Optimisers. Would be interesting to know if https://github.com/FluxML/Flux.jl/pull/2127 changes this, as it mutates the tracking arrays directly instead.
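
A rough way to probe that (only a sketch, not a claim about Optimisers internals) is to check whether update! hands back the same BatchNorm struct and the same tracking arrays it was given:

using Optimisers
using Flux

bn = BatchNorm(3)
state = Optimisers.setup(Optimisers.Adam(), bn)
x = rand(Float32, 3, 5)
grads = Flux.gradient(b -> sum(b(x)), bn)[1]
state, bn2 = Optimisers.update!(state, bn, grads)

bn === bn2      # false would mean the struct was rebuilt rather than mutated
bn.μ === bn2.μ  # false would mean the running-mean arrays are now distinct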

mcabbott commented on Dec 03 '22

Haven't had a chance to look. My guess is that somehow the mutation of the BatchNorm struct is lost by Optimisers. Would be interesting to know if FluxML/Flux.jl#2127 changes this, as it mutates the tracking arrays directly instead.


To add a data point: I am using the latest versions of Flux and Optimisers, and it seems that this bug has been fixed.

skyleaworlder commented on May 08 '23

We should add a test to make sure we don't have regressions

CarloLucibello commented on May 08 '23

We should add a test to make sure we don't have regressions

I see the related test was added in Flux. Does Optimisers need a test for update!?
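
A possible shape for such a test on the Optimisers side (only a sketch; it assumes Flux as a test dependency so that a BatchNorm layer is available):

using Test
using Optimisers
using Flux
using Statistics

@testset "update! with and without rebinding agree" begin
    m1, m2 = BatchNorm(3), BatchNorm(3)   # BatchNorm initialises deterministically
    s1 = Optimisers.setup(Optimisers.Adam(), m1)
    s2 = Optimisers.setup(Optimisers.Adam(), m2)
    for _ in 1:2
        x = rand(Float32, 3, 5)
        g1 = Flux.gradient(m -> mean(m(x)), m1)[1]
        g2 = Flux.gradient(m -> mean(m(x)), m2)[1]
        s1, m1 = Optimisers.update!(s1, m1, g1)  # rebind to the returned model
        Optimisers.update!(s2, m2, g2)           # ignore the return value
    end
    x = rand(Float32, 3, 5)
    @test m1(x) ≈ m2(x)
end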

skyleaworlder commented on May 08 '23