Flux.jl
ForwardDiff + destructure is different from Zygote, on a model with BatchNorm
Package Version
Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7
Julia Version
1.8
OS / Environment
OS
Describe the bug
The ForwardDiff example from Optimisers unfortunately does not give the same result as the Zygote version.
Steps to Reproduce
using ForwardDiff  # an example of a package which only likes one array
using Flux
using Random
using Optimisers

Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)

image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

rule = Optimisers.Adam(0.001f0, (0.9f0, 0.999f0), 1.1920929f-7)
flat, re = Flux.destructure(model)
st = Optimisers.setup(rule, flat)  # state is just one Leaf now

∇flat = ForwardDiff.gradient(flat) do v
    loss(re(v), image)  # re(v), rebuild a new object like model
end

st, flat = Optimisers.update(st, flat, ∇flat)
@show loss(re(flat), image);
and the Zygote version:
using Flux
using Random

Random.seed!(123)

model = Chain(  # much smaller model example, as ForwardDiff is a slow algorithm here
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    BatchNorm(5, relu),
    Conv((3, 3), 5 => 3, stride=16),
)

image = rand(Float32, 224, 224, 3, 1);
@show sum(model(image));

loss(m, x) = sum(m(x))

opt = Flux.Adam(0.001f0, (0.9f0, 0.999f0), 1.1920929f-7)
θ = Flux.params(model)

grads = Flux.gradient(θ) do
    loss(model, image)
end

Flux.update!(opt, θ, grads)
@show loss(model, image);
Expected Results
sum(model(image)) = -0.33076355f0
loss(model, image) = -5.064876f0
Observed Results
sum(model(image)) = -0.33076355f0
loss(re(flat), image) = -7.7023053f0
Relevant log output
No response
Can reproduce.
I note that commenting out BatchNorm removes the discrepancy.
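For reference, a sketch of the BatchNorm-free variant (the scripts above stay the same, only the model changes):
model = Chain(  # as in the reproducer, with the BatchNorm layer removed
    Conv((3, 3), 3 => 5, pad=1, bias=false),
    Conv((3, 3), 5 => 3, stride=16),
)
# with this model, the ForwardDiff and Zygote scripts above give matching losses after one update step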
And that inserting trainmode!(model) produces this error:
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12})
Stacktrace:
...
[12] _track_stats!(bn::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, μ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, σ²::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, reduce_dims::Vector{Int64})
@ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:278
[13] _norm_layer_forward(l::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}; reduce_dims::Vector{Int64}, affine_shape::NTuple{4, Int64})
@ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:253
...
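For context, a sketch of where trainmode! could be inserted in the ForwardDiff reproducer to hit this error:
Flux.trainmode!(model)   # force BatchNorm into train mode instead of the automatic default
flat, re = Flux.destructure(model)
∇flat = ForwardDiff.gradient(v -> loss(re(v), image), flat)   # throws the MethodError above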
For this last error, we'd need some way to pull the value out of the Duals passed to _track_stats!. Alternatively, is there some way to mark that function as non-differentiable?
I think we can harmlessly just insert ForwardDiff.value into that broadcast. It's ForwardDiff-specific but maybe worth having?
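A rough sketch of that suggestion (not the actual Flux code; _maybe_value is a hypothetical helper name):
using ForwardDiff: Dual, value

_maybe_value(x) = x               # ordinary numbers pass through unchanged
_maybe_value(x::Dual) = value(x)  # Duals drop their partials, keeping only the primal value

# inside _track_stats!, the running-statistics broadcasts would then look roughly like
#   bn.μ  = (1 - mtm) .* bn.μ  .+ mtm .* _maybe_value.(μnew)
#   bn.σ² = (1 - mtm) .* bn.σ² .+ mtm .* _maybe_value.(σ²new)
# so the stored statistics never see a Dual and the Float32 conversion error goes away.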
For automatic train-mode, if we do something like https://github.com/FluxML/NNlib.jl/pull/434 then we can have a method for AbstractArray{<:Dual}. But I don't know what package ought to own it.
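A minimal sketch of what that could look like, assuming a within_gradient-style query in the spirit of NNlib#434 (names are illustrative, not a settled API):
using ForwardDiff

within_gradient(x) = false                                    # fallback: assume we are not inside AD
within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true  # ForwardDiff is differentiating through x

# reverse-mode ADs would hook in via an rrule that returns true instead, and
# BatchNorm's automatic train-mode check could then call within_gradient on its input.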
So one fly in the ointment is that I was hoping to move _track_stats! et al. out to NNlib soonish, and NNlib can't rely on ForwardDiff being loaded. Which is a good segue to the automatic train-mode suggestion above:
One path would be to use Requires in NNlib to get non-ChainRules ADs to conform to https://github.com/FluxML/NNlib.jl/pull/434. Another would be adding it to AbstractDifferentiation.jl, which already uses Requires for ForwardDiff, ReverseDiff, and Tracker. Any other ideas I had (e.g. splitting off Dual numbers from ForwardDiff and having NNlib define methods on them) feel too far off to be feasible.
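For the Requires route, a hedged sketch of what the glue in NNlib might look like (the module name and UUID are placeholders, and within_gradient is the assumed hook from the sketch above):
module NNlibSketch  # stand-in for NNlib itself, purely illustrative

using Requires

within_gradient(x) = false  # generic fallback

function __init__()
    # the UUID below is a placeholder; real glue code would use ForwardDiff's UUID from its Project.toml
    @require ForwardDiff="00000000-0000-0000-0000-000000000000" begin
        using .ForwardDiff
        within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
    end
end

end # module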
I had to check, but NNlib is much lighter than ForwardDiff, even once ForwardDiff moves to StaticArraysCore. NNlib already loads Requires, so that might be fine:
julia> @time_imports using NNlib
0.3 ms Requires
0.9 ms DelimitedFiles
0.3 ms Compat
80.0 ms ChainRulesCore
0.3 ms Adapt
24.3 ms NNlib 55.88% compilation time (14% recompilation)
AbstractDifferentiation.jl is pretty similar. Let me file an issue over there and see how it goes. We can always look into the NNlib option in parallel.
Besides detecting whether you are within AD (also an issue for dropout), the problem with BatchNorm is that ForwardDiff runs the forward pass several times (chunked mode, for any large array).
I can't think of a good way to detect that. Perhaps we should make it a clear error instead?
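To make the chunking point concrete, here is a small self-contained illustration (the counter is only for demonstration). ForwardDiff evaluates the function once per chunk of partials (the Duals in the stacktrace above carry 12 partials, the default chunk size, while the flattened parameter vector is much longer), so any side effect in the forward pass, such as BatchNorm's running-statistics update, happens several times per gradient call.
using ForwardDiff

const ncalls = Ref(0)
f(v) = (ncalls[] += 1; sum(abs2, v))   # stand-in "forward pass" with a side effect

ForwardDiff.gradient(f, randn(Float32, 100))
@show ncalls[]   # greater than 1: one evaluation per chunk of the 100 inputs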