Optimisers.jl
Consistency in the type behavior of restructure
This was discovered in https://github.com/SciML/NeuralPDE.jl/issues/533 as an issue that only showed itself as an incorrect gradient: the primal pass of what was being trained was in Float64, the reverse pass gave a Float64, the loss function printed out a Float64, and everything looked fine, except that the Flux neural network was magically "a bit more janky", in that it had a much higher probability of failing CI tests for a reason nobody could figure out for 5 months. Finally it was discovered that parts of the gradient were calculated in Float32 because the Flux.Chain had Float32 parameters inside it. This showed that re(p) does not "always" respect the types of p.
But it doesn't "always" respect the types of the Flux.Chain either. For example, for a standard Flux.Chain of Dense layers with Float32 parameters, you get:
- re(p::Vector{Float64}) computes in Float32
- re(p::CuVector{Float32}) computes on the GPU in Float32
- re(p::Vector{Dual}) computes with Dual numbers
- re(p::Vector{ComplexF32}) computes with Float32
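For concreteness, here is a minimal sketch of the first case (hypothetical layer sizes; the exact outcome depends on the Optimisers.jl version under discussion):

julia> using Flux, Optimisers

julia> m = Chain(Dense(2 => 3, tanh), Dense(3 => 1));   # Float32 parameters by default

julia> p32, re = Optimisers.destructure(m);

julia> p64 = Float64.(p32);                             # promote the flat parameter vector

julia> eltype(re(p64)[1].weight)                        # Float32 per the first item above, even though p64 is a Vector{Float64}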
And now let's have some fun:
- re(p::CuVector{Float64}) computes ???. My guess is CuVector{Float32}?
- re(p::ReverseDiff.TrackedArray) computes ??? My guess is Array{TrackedReal{Float32}}?
I understand that this isn't intended behavior and comes out of some quirks of ProjectTo, which exposes some (IMO odd) behaviors of a ChainRules internal to users who are likely not experts in the autodiff system.
Now the problem that I have with it is that discovering this behavior is rather hard, because if you do anything other than the simplest "just use the neural network", almost any case will not expose to the user that this behavior exists. For example,
- (p[end] .* re(p))::typeof(p)
- (p[end] .+ re(p))::typeof(p)
- ...

all hold in the examples I described, because the type demotion is countered by the type promotion applied by essentially any other computation involving things of eltype(p). Thus unless re(p) is the only operation that is used (in which case you probably don't need to be using restructure/destructure), some other operation in the primal will mask the demotion and your forward pass will look like it computed using typeof(p). It will only present itself to a user in the gradient pass.
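As a sketch of that masking effect (continuing the hypothetical m, re, and p64 from above; the input x and the scalar choice are assumptions for illustration):

julia> x = rand(Float32, 2);               # Float32 input matching the original Chain

julia> y = p64[end] .* re(p64)(x);         # the Float64 scalar promotes the broadcast result

julia> eltype(y)                           # Float64 either way, masking whether the layer computed in Float32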
Thus I understand @mcabbott's reasoning behind saying it's not a gradient correctness issue (since it is correctly calculating the gradients of the object that is actually reconstructed), but I have now traced many different cases that I thought were just "Flux janky behavior" and "I don't know why FastChain works here but Flux.Chain doesn't" back to this same behavior. It may not be a gradient correctness issue, but it only presents itself as one in the downstream libraries where I have found it, it only really exposes itself if you try to look into a seemingly incorrect gradient, and if it quacks like 🦆?
I understand that this behavior is now documented, but I'm not sure a behavior that presents itself like that is sufficiently handled just by documentation because it's hard to even figure out that something is going wrong without investigating the gradient calculation.
What could be done?
I would propose that we just make the behavior undeniably straightforward and consistent: either always make re(p) compute using values of typeof(p), or always make it compute using the values from the original Flux.Chain. Either choice is an easily explainable and predictable behavior. The current middle ground is not easy to explain or predict.
Always matching p is the more predictable behavior in the Julia ecosystem. If you stick a complex number in as the initial condition of an ODE solver, as the initial guess for a value in Optim, or as the starting point for IterativeSolvers or NLsolve, then any generic code I can think of will carry out the computation in the type that p provides. In many cases generic codes will just error if they can't handle it, but they try to compute using p. Non-generic codes immediately throw method errors describing what the allowed inputs are. I cannot think of another example in the Julia ecosystem where the "computation type" for f(p) matches neither p nor a fixed type, but instead matches the internal types of the fields of f, only sometimes, while at other times it matches p.
If it always matches the Flux.Chain, at least that would be clearly visible, since when you do it on a CuArray you see you get an Array back and you think: oh, I see how this works. If I want the GPU, then I |> gpu the chain, because it doesn't convert to p. Got it. With the current behavior, you see that re(p) works on the GPU, so okay, why not just do re(p::Array{Float64}) as a quick way to convert to Float64? And if you think like that, you get burned.
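A sketch of that "convert the model, not the flat vector" workflow (assuming a functional CUDA setup; same hypothetical Chain as above):

julia> using Flux, CUDA, Optimisers

julia> m_gpu = Chain(Dense(2 => 3, tanh), Dense(3 => 1)) |> gpu;

julia> p_gpu, re_gpu = Optimisers.destructure(m_gpu);   # p_gpu is a CuArray{Float32}

julia> re_gpu(p_gpu);                                   # reconstructs a GPU model; the device choice was made on the Chain, not on p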
The other behavior could be to throw an error in any case where a type conversion is necessary. If you want re(p::Array{Float64}) to work, go back and |> f64 the neural network. Now, this will cause some issues with making libraries work, but it's a nice (overly) safe option that would ensure there are no surprises.
Or, as @ToucheSir suggested, maybe these are two different functions, or two different options, and you should be required to choose which behavior you want. Some kind of re(p,Optimisers.NoConvert()) and re(p,Optimisers.Convert()).
Those 4 behaviors would be clear and easily predictable. I think the only option I would be adamantly against is the current behavior.
I would actually be in favour of behaviour 3: destructure is fundamentally a function that promises too much, and even after the effort made towards tightening that (+ improving correctness) when porting over to Optimisers, one could argue it still does. Disallowing promotion while destructuring would also reduce some internal complexity.
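To make that concrete, here is a purely hypothetical sketch of the per-leaf check behaviour 3 implies; check_eltype is not part of Optimisers.jl, it only illustrates the kind of test that would run on each array leaf during reconstruction:

julia> check_eltype(x::AbstractArray{T}, flat::AbstractVector{S}) where {T,S} =
           T === S ? flat : throw(ArgumentError("model leaf has eltype $T but flat vector has eltype $S; convert the model (e.g. with Flux.f64) instead of relying on re(p) to do it"));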
Now, another tricky thing is what to do about structured array types. Here I think we just have to enumerate as many weird cases as we can think of and come to an agreement on how to handle them all consistently. One such example:
julia> d = Dense(Diagonal(rand(Float32, 3)), false)
Dense(3 => 3; bias=false) # 9 parameters
julia> d.weight
3×3 Diagonal{Float32, Vector{Float32}}:
0.24043 ⋅ ⋅
⋅ 0.657887 ⋅
⋅ ⋅ 0.52947
julia> p, re = destructure(d)
( [1] = 0.24043
[5] = 0.657887
[9] = 0.52947, Restructure(Dense, ..., 9))
julia> p
9-element SparseArrays.SparseVector{Float32, Int64} with 3 stored entries:
[1] = 0.24043
[5] = 0.657887
[9] = 0.52947
julia> re(p)
Dense(3 => 3; bias=false) # 9 parameters
julia> re(p) |> dump
Dense{typeof(identity), Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}, Bool}
weight: Diagonal{Float32, SparseArrays.SparseVector{Float32, Int64}}
diag: SparseArrays.SparseVector{Float32, Int64}
n: Int64 3
nzind: Array{Int64}((3,)) [1, 2, 3]
nzval: Array{Float32}((3,)) Float32[0.24042994, 0.6578865, 0.52947]
bias: Bool false
σ: identity (function of type typeof(identity))
And another one:
julia> d = Dense(rand(Float32, 3, 2), @SArray ones(3))
Dense(2 => 3) # 9 parameters
julia> p, re = destructure(d)
(Float32[0.9659148, -0.7210188, 0.20607175, 0.7583495, 0.35627228, -0.5444089, 0.0, 0.0, 0.0], Restructure(Dense, ..., 9))
julia> re(p)
Dense(2 => 3) # 9 parameters
julia> re(p) |> dump
Dense{typeof(identity), Matrix{Float32}, SizedVector{3, Float32, Vector{Float32}}}
weight: Array{Float32}((3, 2)) Float32[0.9659148 0.7583495; -0.7210188 0.35627228; 0.20607175 -0.5444089]
bias: SizedVector{3, Float32, Vector{Float32}}
data: Array{Float32}((3,)) Float32[0.0, 0.0, 0.0]
σ: identity (function of type typeof(identity))
julia> cu_p = cu(p)
9-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
0.9659148
-0.7210188
0.20607175
0.7583495
0.35627228
-0.5444089
0.0
0.0
0.0
julia> re(cu_p) |> dump
Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}
weight: CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
buffer: CUDA.Mem.DeviceBuffer
ctx: CuContext
handle: Ptr{Nothing} @0x0000000002ab0400
valid: Bool true
ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0800)
bytesize: Int64 24
async: Bool false
refcount: Base.Threads.Atomic{Int64}
value: Int64 1
maxsize: Int64 24
offset: Int64 0
dims: Tuple{Int64, Int64}
1: Int64 3
2: Int64 2
bias: SizedVector{3, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}
data: CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}
storage: CUDA.ArrayStorage{CUDA.Mem.DeviceBuffer}
buffer: CUDA.Mem.DeviceBuffer
ctx: CuContext
handle: Ptr{Nothing} @0x0000000002ab0400
valid: Bool true
ptr: CuPtr{Nothing} CuPtr{Nothing}(0x0000000701bc0a00)
bytesize: Int64 12
async: Bool false
refcount: Base.Threads.Atomic{Int64}
value: Int64 1
maxsize: Int64 12
offset: Int64 0
dims: Tuple{Int64}
1: Int64 3
σ: identity (function of type typeof(identity))
I repeat that no incorrect gradients have been displayed here. Calling other features you happen to dislike in some context gradient bugs is just muddying the waters. (There are known gradient bugs, they are marked "bug" in the issues here.)
Maybe it's helpful to understand what the goals are of the present design:
1. Don't assume that all arrays have the same type. The optimisation rules don't need this, nor does destructure.
2. If some parameters of the model are real, they must never be made complex. This is a correctness question.
3. One of the use cases for destructure is to make something ForwardDiff.jl can understand. Thus Dual numbers ought to propagate everywhere.
4. Try to preserve element types: unintended promotion to Float64 is an easy source of major performance problems. (And, per point 1, you are welcome to have different precisions in different parts of the model; destructure will assume that wasn't an accident.)
For 3., you may recall that https://github.com/FluxML/Optimisers.jl/pull/66 was precisely to address your complaint that Dual numbers did not propagate through some operations.
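As a sketch of that use case (hypothetical model and loss, not code from the original discussion):

julia> using ForwardDiff, Flux, Optimisers

julia> m = Chain(Dense(2 => 3, tanh), Dense(3 => 1));

julia> p, re = Optimisers.destructure(m);

julia> x = rand(Float32, 2);

julia> loss(q) = sum(abs2, re(q)(x));       # Dual numbers must propagate through re(q)

julia> ForwardDiff.gradient(loss, p);       # works because the rebuilt model carries the Dual eltype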
Since ReverseDiff.jl also likes flat arrays, not nested trees, the same should go for its tracked arrays. If they don't propagate, I think that's a bug. But no need to guess: Tracker's arrays seem to work fine; something seems to make a Vector{ReverseDiff.TrackedReal}, but surely that could be solved.
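Continuing the same hypothetical m, re, p, and x, the analogous ReverseDiff check would be (whether the rebuilt weights come out as a TrackedArray or fall back to Vector{TrackedReal} is exactly the open question):

julia> using ReverseDiff

julia> ReverseDiff.gradient(q -> sum(abs2, re(q)(x)), p);   # inspect the eltypes inside re(q) here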
At present, this package does not know about GPU arrays, and thus makes no distinctions. If you think it's confusing that re from a CPU model can be fed a GPU array and construct a GPU model, it would not be very hard to forbid that. (Models with a mix of GPU and CPU arrays won't work very well right now. Various policies could be adopted, but nobody has got around to it.)
Re structured arrays, I suspect most of them should be marked @functor. I think you are suggesting that the sparse array outcome is undesirable, but I can't reproduce it on Julia nightly, so I suspect some of the weirdness about Diagonal sometimes being seen as sparse has gone away (with SparseArrays.jl moving out?)
julia> v, re = destructure(Diagonal([0.11, 0.22]));
julia> v
4-element Vector{Float64}:
0.11
0.0
0.0
0.22
julia> re([1.0, 2.0, 3.0, 4.0])
2×2 Diagonal{Float64, Vector{Float64}}:
1.0 ⋅
⋅ 4.0
This discards the 2.0 and 3.0 components of the new parameters. The zeros in v are structural, not accidental, so they are preserved.
I don't claim to know what the right answer is, so I posted those examples because it's not clear if they'd be considered consistent enough to pass muster. Another one is Transpose/Adjoint, which projects back to a dense array (AIUI it's a no-op) rather than the wrapper type.
On a meta level, I feel even more strongly now that the behaviour of destructure was way underspecified when it was first written. Not only is there disagreement about what counts as sufficiently "reconstructed" for various parameter types, but the proliferation of fancy array types in the ecosystem makes post-hoc specification (as we're attempting now) a significant undertaking. Here I'd be interested to know from @willtebbutt or others working on packages like ParameterHandling.jl how they handle this particular can of worms :)
Ok. Adjoint should now reconstruct:
julia> destructure(rand(2)')[2]([1.0, 2.0])
1×2 adjoint(::Vector{Float64}) with eltype Float64:
1.0 2.0
julia> destructure(transpose(rand(2,2)))[2]([1, 2, 3, 4])
2×2 transpose(::Matrix{Float64}) with eltype Float64:
1.0 2.0
3.0 4.0
I agree that things are a bit under-specified. Like everything else in Julia really -- it's a bit of an exploration to see what properties turn out to be useful, and how to compose them.
> I repeat that no incorrect gradients have been displayed here. Calling other features you happen to dislike in some context gradient bugs is just muddying the waters. (There are known gradient bugs, they are marked "bug" in the issues here.)
I don't disagree. There are no incorrect gradients here by the definition now in the docs. It's just an issue that only presents itself to downstream users via incorrect gradients (as demonstrated) in functions that they expect to behave the way a generic Julia function generally does. It's a very subtle distinction. I agree it's not incorrect as documented, but it is also very hard to spot that it's happening in most cases (with demonstrations above as to why).