Optimisers.jl
Add `trainables_nt`
This is a proposal for an alternative to `destructure` which doesn't completely flatten the parameters but returns a nested named tuple. The associated reconstructor can be used on ComponentArrays as well.
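Intended usage, roughly (a sketch of the API in this PR; the toy model is made up for illustration):

```julia
using Optimisers

model = (weight = rand(3, 3), bias = zeros(3))

ps, re = trainables_nt(model)  # ps is a nested NamedTuple holding the trainable arrays
model2 = re(ps)                # reconstructs the original structure, without copying
```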
Setting differentiability aside, is `fmapstructure` not sufficient, because of how vectors are handled (e.g. the layers in a `Chain`)?
Exactly. And we need a nested namedtuple-only return in order to be compatible with ComponentArrays.
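A toy illustration of the difference (my sketch, not code from this PR): `fmapstructure` mirrors the original containers, while ComponentArrays needs named fields at every level.

```julia
using Functors

# fmapstructure mirrors the original containers, so a Tuple child stays a Tuple:
fmapstructure(identity, (layers = ([1.0], [2.0]),))
# -> (layers = ([1.0], [2.0]),)

# whereas ComponentVector wants NamedTuples all the way down, hence the _1/_2 renaming:
# ComponentVector((layers = ([1.0], [2.0]),))  # fails, if I remember right: Tuples have no field names
```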
What about replacing `destructure` with this code plus the ComponentArrays construction, as opposed to adding this as a separate function? That would move a lot of the tricky stuff to ComponentArrays.
@mcabbott
Wait, there are two big differences from `fmapstructure` / `Flux.state`:
- this is only trainable parameters, and
- tuples & vectors become NamedTuples with made-up field names.
ComponentArrays has no notion of shared parameters. That's a large part of what makes everything touching Functors tricky. (In fact the replacement of a vector with a NamedTuple opens the door to weirdness here, before you get to ComponentArrays, as you replace a mutable thing with an immutable one. Probably not in a way that matters for Flux models.)
Example with this:
```julia
julia> sh = [1f0, 2f0];

julia> ps, re = trainables_nt((sh, sh, [3,4.]))
((_1 = Float32[1.0, 2.0], _2 = Float32[1.0, 2.0], _3 = [3.0, 4.0]), Optimisers.RestructureFromNT{Tuple{Vector{Float32}, Vector{Float32}, Vector{Float64}}}((Float32[1.0, 2.0], Float32[1.0, 2.0], [3.0, 4.0])))

julia> ps._1 === ps._2
true

julia> v = ComponentVector(ps);

julia> getfield(v, :data) |> println
[1.0, 2.0, 1.0, 2.0, 3.0, 4.0]

julia> v[3] = 99;

julia> re(v)  # sharing is broken
([1.0, 2.0], [99.0, 2.0], [3.0, 4.0])
```
And unrelated to sharing:
```julia
julia> re(v)[1] |> eltype  # accidental promotion is back
Float64

julia> re(v)[1]  # no copy on reconstruction, but will view(::CuArray) work everywhere?
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 1.0
 2.0
```
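(If I'm reading this right, the view-based reconstruction also means mutation flows both ways, since no copy separates the model from the flat buffer:)

```julia
w = re(v)[1]   # a view into getfield(v, :data)
w[1] = 7.0
v[1] == 7.0    # true: the ComponentVector sees the mutation
```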
cf. `destructure`:

```julia
julia> v2, re2 = destructure((sh, sh, [3,4.]))
([1.0, 2.0, 3.0, 4.0], Restructure(Tuple, ..., 4))

julia> v2[2] = 999;

julia> re2(v2)
(Float32[1.0, 999.0], Float32[1.0, 999.0], [3.0, 4.0])
```
When I last looked, ComponentArrays also made more whole copies in the gradient.
More broadly, what's this for? Why do we care about ComponentArrays?
> More broadly, what's this for? Why do we care about ComponentArrays?
I would like to have something in the `v, re = destructure(model)` style, but for which reconstruction is copyless and which is also compatible with ComponentArrays. This seems to be quite needed; see https://github.com/FluxML/Flux.jl/issues/2413#issuecomment-2033361707.
I think we can provide it and see if it is used.
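Concretely, the workflow I'd like to enable (a sketch, assuming this PR's API):

```julia
using Optimisers, ComponentArrays

model = (weight = rand(2, 2), bias = zeros(2))

ps, re = trainables_nt(model)  # nested NamedTuple of the trainable parameters
v = ComponentVector(ps)        # flat vector view, for ODE-solver / optimiser interfaces
model2 = re(v)                 # copyless reconstruction from the flat vector
```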
I need help with the rrule of the reconstructor. It works for named tuples but not for component arrays:
```julia
using Zygote, Optimisers, ComponentArrays, Test

m = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(m)
Zygote.refresh()

gps = gradient(x -> re(x)[1][2], ps)[1]
@test gps == (_1 = [0.0, 1.0, 0.0], _2 = nothing)  # ok

v = ComponentVector(ps)
gv = gradient(x -> re(x)[1][2], v)[1]  # this is `nothing`!!!!
```
The relevant rule is:

```julia
function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
    model = restructure_from_nt(x, ps)
    proj_ps = ProjectTo(ps)

    function restructure_from_nt_back(Δmodel_raw)
        Δmodel = unthunk(Δmodel_raw)
        walk = RestructureFromNamedTupleBackWalk()
        function exclude(x)
            @show "exclude" x isnumeric(x)
            # i += 1
            # return i > 1
            return isnumeric(x)
        end
        Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ
            @show "fmap" Δ p
            return Δ
        end
        Δpst = Tangent{typeof(Δps)}(; Δps...)
        @show "rrule" Δmodel x ps Δps Δpst  # here Δps = (_1 = [0.0, 1.0, 0.0], _2 = ChainRulesCore.ZeroTangent())
        @show typeof(Δmodel) typeof(ps) typeof(Δps)
        return (NoTangent(), NoTangent(), Δps)
        # return (NoTangent(), NoTangent(), proj_ps(Δpst))
    end

    return model, restructure_from_nt_back
end
```
```julia
struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
    @show 1 typeof(Δmodel) typeof(ps)
    Δm = make_named_tuple(Δmodel)
    @show 2 typeof(Δm) ps Δm
    Δm === nothing && return nothing
    Δm === ZeroTangent() && return ZeroTangent()
    y = mapvalue(recurse, ps, Δm)
    @show 3 typeof(Δmodel) typeof(Δm) typeof(y)
    return y
end
```
Why do I get a `nothing` gradient? Am I doing something wrong with the projection?
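For context on the projection part: on plain arrays, `ProjectTo` just standardises the tangent back to the primal's type (a minimal sketch; whether `ProjectTo` of a `ComponentVector` can absorb a NamedTuple-shaped tangent like `Δpst` is exactly what I'm unsure about):

```julia
using ChainRulesCore

p = ProjectTo([1.0, 2.0])  # projector remembering the primal's eltype and shape
p([1f0, 2f0])              # -> 2-element Vector{Float64}: tangent converted to match
```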