Optimisers.jl
Utility for walking a tree (e.g. gradients) w.r.t. a model
Motivation and description
Using `trainable`, we can walk a model and apply a function only to its trainable parameters. But the gradient from Zygote is a named tuple without this information.
Normally, for optimizers this is fine, because our function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to walk the entire tree of gradients first, to compute something like a global norm term. In this case, we need a pass over the gradient outside of the update context.
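For instance, a global-norm pass today might look like this rough sketch, where `grads` stands for a Zygote gradient named tuple (an assumption for illustration); note that nothing here filters out non-trainable leaves:

```julia
using Functors

# Rough sketch: accumulate a global norm over *every* array leaf of the
# gradient tree, because fmap alone knows nothing about trainability.
total = Ref(0.0)
fmap(grads) do dx                 # `grads` is an assumed Zygote gradient
    dx isa AbstractArray && (total[] += sum(abs2, dx))
    dx                            # return the leaf unchanged; fmap rebuilds the tree
end
gnorm = sqrt(total[])
```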
Possible Implementation
We can include a `maptrainable(f, model, [gradient])` (or better name) function that maps a function over the trainable parameters of `model` (see the sketch after the list below).
- If another tree like `gradient` is passed, then `f` is applied to the leaves of `gradient` (i.e. approximately `fmap(TrainableWalk(f), gradient, model)`, using the last argument to filter the walk).
- If no other tree is passed, we just apply `f` to `model` (this is a simple walk, but maybe it is good for consistency).
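A minimal sketch of what this could look like, assuming the proposed name `maptrainable` and leaning on `Optimisers.trainable` (shared parameters and `nothing` gradients are ignored for brevity, and the result comes back as nested tuples rather than a rebuilt model structure):

```julia
using Functors, Optimisers

# Hypothetical sketch of the proposal above, not an existing API.
# Trainability is read off `model`; `f` is applied to the leaves of
# `gradient` that sit under trainable parameters.
function maptrainable(f, model, gradient)
    Functors.isleaf(model) && return f(gradient)
    tr = Optimisers.trainable(model)          # only the trainable children
    map(keys(tr)) do k
        maptrainable(f, tr[k], getfield(gradient, k))
    end
end
maptrainable(f, model) = maptrainable(f, model, model)   # single-tree form
```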
Ideally, I think this implementation would underlie `update` (i.e. `update` is `maptrainable` with `f` specialized to call `apply`).
Related: https://github.com/FluxML/Optimisers.jl/pull/57. We have proposals for `map` and `reduce`, but does it make sense to try for a `mapreduce`?
Agreed, with the ability to add more trees to the call, as described above.
> If another tree like `gradient` is passed, then `f` is applied to the leaves of `gradient` (i.e. approximately `fmap(TrainableWalk(f), gradient, model)` using the last argument to filter the walk).
I think the most obvious `t_mapreduce(f, r, model, grads)` would always call `f(x, dx)`, but take trainability from the model. The present `fmap(f, xs, ys)` always calls `f(x, y)`:
```julia
julia> fmap(println, (a=1, b=2), (a="!", b="?"))
1!
2?
(a = nothing, b = nothing)

julia> sh = [1.0]; fmap(println, (a=sh, b=sh), (a="!", b="?"))
[1.0]!
(a = nothing, b = nothing)
```
The tricky bit, as usual, will be shared parameters. Here `fmap` simply ignores the `y` belonging to a shared `x`. This `fmap(f, xs, ys)` is a half-baked feature; I think `update!` was the original target, but it's not actually right for that.
The walk done by `Optimisers.update!` instead adds the distinct `dx`s belonging to a shared `x` before calling `apply!`. I wonder how often that would be correct; for the gradient-norm example it probably would be. To write `update!` (ignoring its return value) you would need `t_mapreduce(f, Returns(nothing), model, grads, state_tree)`, where we add `dx` but not `state`?
```julia
julia> Optimisers.setup(Momentum(), (a=sh, b=sh))
(a = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), b = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]))

julia> ans.a === ans.b
true
```
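To make the `t_mapreduce` semantics concrete, a hypothetical sketch (the name and signature come from this thread, not an existing API; shared parameters are ignored here, though as noted above they would need the distinct `dx`s accumulated first):

```julia
using Functors, Optimisers

# Hypothetical t_mapreduce: f sees (x, dx) pairs, but trainability is
# taken from the model, not the gradient. No handling of shared (===)
# parameters or `nothing` gradients in this sketch.
function t_mapreduce(f, op, model, grads)
    Functors.isleaf(model) && return f(model, grads)
    tr = Optimisers.trainable(model)
    mapreduce(op, keys(tr)) do k
        t_mapreduce(f, op, tr[k], getfield(grads, k))
    end
end

# e.g. a global gradient norm over trainable parameters only:
# sqrt(t_mapreduce((x, dx) -> sum(abs2, dx), +, model, grads))
```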
This all always feels like we have slightly the wrong abstractions.
For the norm use-case, another thing that would be handy is if I could `destructure` the gradient to flatten it, but only keep the trainable params as governed by the model. Then I can just take a norm directly on the flat vector.
Or maybe a more composable thing would be if I could walk the model & gradient simultaneously, and map non-trainable gradients to `nothing`, returning an updated gradient that only has non-`nothing` entries for trainable params. Then I could do whatever I wanted with that (walk it again with `fmap`, flatten it with `destructure`, etc.), as in the sketch below.
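One possible sketch of that composable version, with a hypothetical helper name, assuming `NamedTuple`-structured nodes and ignoring shared parameters:

```julia
using Functors, Optimisers, LinearAlgebra

# Hypothetical helper: walk model and gradient together, nulling out the
# gradients of non-trainable children, so only trainable entries survive.
function trainable_gradient(model, grad)
    (grad === nothing || Functors.isleaf(model)) && return grad
    ch = Functors.children(model)             # all children of this node
    tr = keys(Optimisers.trainable(model))    # names of the trainable subset
    vals = map(keys(ch)) do k
        k in tr ? trainable_gradient(ch[k], getfield(grad, k)) : nothing
    end
    NamedTuple{keys(ch)}(vals)
end

# The pruned gradient can then be flattened or walked again, e.g. (assuming
# destructure skips the `nothing` entries):
# norm(first(Optimisers.destructure(trainable_gradient(model, grads))))
```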
A simpler version of this came up in conversation over RL models on Slack today. The current incantation for updating one model's parameters based on the moving average of another model's is:
```julia
# soft (Polyak) update: move each target parameter towards the policy's
for (t, p) in zip(Flux.params(target), Flux.params(policy))
    t .= (1 - tau) .* t .+ tau .* p
end
```
To which I proposed:
```julia
Functors.fmap(m_target, m_policy; walk = Optimisers._Trainable_biwalk()) do t, p
    t .= (1 - tau) .* t .+ tau .* p   # .= so the update happens in place, as in the loop above
end
```
It should take no time to package up the latter as a `mapparams` function on our side. The questions are where it should live (Flux or Optimisers) and what it should be called (e.g. `maptrainable` instead)?
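For reference, the packaging could be as small as this sketch (`mapparams` is the name floated above, not an existing function, and `Optimisers._Trainable_biwalk` is an internal walk):

```julia
using Functors, Optimisers

# Hypothetical packaging of the snippet above: apply f across the trainable
# parameters of one or more models, in lockstep.
mapparams(f, m, ms...) = Functors.fmap(f, m, ms...; walk = Optimisers._Trainable_biwalk())

# Usage, matching the soft-update example:
# mapparams((t, p) -> t .= (1 - tau) .* t .+ tau .* p, m_target, m_policy)
```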