
ForwardDiff + destructure is different from Zygote, on a model with BatchNorm

lazarusA opened this issue on Nov 21 '22 • 7 comments

Package Version

Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7

Julia Version

1.8

OS / Environment

OS

Describe the bug

The ForwardDiff + destructure example from Optimisers unfortunately does not give the same result after one update step as the equivalent Zygote version.

Steps to Reproduce

using ForwardDiff  # an example of a package which only likes one array
using Flux
using Random
using Optimisers
Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
          Conv((3, 3), 3 => 5, pad=1, bias=false), 
          BatchNorm(5, relu), 
          Conv((3, 3), 5 => 3, stride=16),
        )
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

rule = Optimisers.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)

flat, re = Flux.destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    loss(re(v), image) # re(v), rebuild a new object like model
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show loss(re(flat),image);

And here is the Zygote version:

using Flux
using Random
Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
          Conv((3, 3), 3 => 5, pad=1, bias=false), 
          BatchNorm(5, relu), 
          Conv((3, 3), 5 => 3, stride=16),
        )
image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

opt = Flux.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7)
θ = Flux.params(model)
grads = Flux.gradient(θ) do 
    loss(model, image)
end

Flux.update!(opt, θ, grads)
@show loss(model, image);

Expected Results

sum(model(image)) = -0.33076355f0
loss(model, image) = -5.064876f0

Observed Results

sum(model(image)) = -0.33076355f0
loss(re(flat), image) = -7.7023053f0


lazarusA • Nov 21 '22

Can reproduce.

I note that commenting out BatchNorm removes the discrepancy.

And that inserting trainmode!(model) produces this error:

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12})
Stacktrace:
...
 [12] _track_stats!(bn::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, μ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, σ²::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, reduce_dims::Vector{Int64})
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:278
 [13] _norm_layer_forward(l::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}; reduce_dims::Vector{Int64}, affine_shape::NTuple{4, Int64})
    @ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:253
...
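
For context, that MethodError is just what happens when a Dual is written into a plain Float32 array, which is what BatchNorm's running-stat fields are. A minimal sketch:

using ForwardDiff

running_stats = zeros(Float32, 3)    # stands in for bn.μ
d = ForwardDiff.Dual(1.5f0, 1f0)     # a Dual carrying one partial
running_stats[1] = d                 # ERROR: no method matching Float32(::ForwardDiff.Dual{...})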

mcabbott • Nov 21 '22

For this last error, we'd need some way to pull the value out of the Duals passed to _track_stats!. Alternatively, is there some way to mark that function as non-differentiable?
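
For reference, the primal value inside a Dual is exposed as ForwardDiff.value, so stripping the Duals is a one-line broadcast; a minimal sketch:

using ForwardDiff

duals = ForwardDiff.Dual.(Float32[1, 2, 3], 1f0)   # Duals with one partial each
ForwardDiff.value.(duals)                          # plain Float32[1.0, 2.0, 3.0], no Duals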

ToucheSir • Nov 22 '22

I think we can harmlessly just insert value into that broadcast. It's ForwardDiff-specific but maybe worth having?
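
Roughly, as a sketch with stand-in variable names (not the actual _track_stats! source):

using ForwardDiff

momentum = 0.1f0
running_μ = zeros(Float32, 5)                        # the layer's Float32 running mean
μ_batch = ForwardDiff.Dual.(rand(Float32, 5), 1f0)   # batch mean computed from Dual inputs

# inserting `value` into the broadcast keeps the running stats as plain Float32s
running_μ .= (1 - momentum) .* running_μ .+ momentum .* ForwardDiff.value.(μ_batch)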

For automatic train-mode, if we do something like https://github.com/FluxML/NNlib.jl/pull/434 then we can have a method for AbstractArray{<:Dual}. But I don't know what package ought to own it.
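
Something like this, as a sketch, where inside_ad is a hypothetical name for whatever function that PR ends up defining (the name and which package owns it are exactly the open questions):

using ForwardDiff

inside_ad(x) = false                                  # default: ordinary forward pass
inside_ad(x::ForwardDiff.Dual) = true
inside_ad(x::AbstractArray{<:ForwardDiff.Dual}) = true

inside_ad(rand(Float32, 3))                           # false
inside_ad(ForwardDiff.Dual.(rand(Float32, 3), 1f0))   # true, so a layer could switch to train mode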

mcabbott • Nov 22 '22

So one fly in the ointment is that I was hoping to move _track_stats! et al. out to NNlib soonish, and NNlib can't rely on ForwardDiff being loaded. Which is a good segue to:

For automatic train-mode, if we do something like FluxML/NNlib.jl#434 then we can have a method for AbstractArray{<:Dual}. But I don't know what package ought to own it.

One path would be to use Requires in NNlib to get non-ChainRules ADs to conform to https://github.com/FluxML/NNlib.jl/pull/434. Another would be adding it to AbstractDifferentiation.jl, which already uses Requires for ForwardDiff + ReverseDiff + Tracker. Any other ideas I had (e.g. splitting off Dual numbers from ForwardDiff and having NNlib define methods on them) feel too far off to be feasible.
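
For concreteness, the Requires route would look roughly like this sketch (module and function names are stand-ins, and the ForwardDiff UUID should be double-checked against its Project.toml):

module NormalisationSketch   # stand-in for NNlib in this sketch

using Requires

inside_ad(x) = false   # fallback: assume we are not inside AD

function __init__()
    # the ForwardDiff method is only defined once ForwardDiff is loaded by the user
    @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
        inside_ad(x::AbstractArray{<:ForwardDiff.Dual}) = true
    end
end

end # module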

ToucheSir • Nov 22 '22

I had to check, but NNlib is much lighter than ForwardDiff, even if the latter moves to StaticArraysCore. NNlib does already load Requires, though, so that might be fine:

julia> @time_imports using NNlib
      0.3 ms  Requires
      0.9 ms  DelimitedFiles
      0.3 ms  Compat
     80.0 ms  ChainRulesCore
      0.3 ms  Adapt
     24.3 ms  NNlib 55.88% compilation time (14% recompilation)

mcabbott • Nov 22 '22

AbstractDiff is pretty similar. Let me file an issue over there and see how it goes. We can always look into the NNlib option in parallel.

ToucheSir • Nov 24 '22

Besides detecting whether you are within AD (also an issue for dropout), the problem with BatchNorm is that ForwardDiff runs the forward pass several times (chunked mode, for any large array).

I can't think of a good way to detect that. Perhaps we should make it a clear error instead?
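
A quick way to see the re-running (ncalls and f below are just an illustration; the exact count depends on the chunk size ForwardDiff picks):

using ForwardDiff

ncalls = Ref(0)
f(x) = (ncalls[] += 1; sum(abs2, x))

ForwardDiff.gradient(f, rand(Float32, 100))
ncalls[]   # > 1: the forward pass re-ran once per chunk, so any mutation inside
           # the model (like BatchNorm's stat tracking) would happen once per chunk too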

mcabbott • Jan 27 '23