Optimisers.jl
Handle gradients and object trees together
Flux has model structs, and Zygote returns NamedTuple gradients for them. With FluxML/Functors.jl#1 we add the ability to handle gradients in Functors.jl - in other words, do a "zipped" tree walk over two equivalent functors and apply functions to them.
This extends that functionality to Optimisers.jl, allowing similar walks and letting us apply `apply` (no pun intended).
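To illustrate, here is a minimal sketch of that zipped walk, assuming the multi-argument `fmap` from FluxML/Functors.jl#1; the model and gradient values are made up:

```julia
using Functors

model = (weight = [1.0, 2.0], bias = [0.0])
grads = (weight = [0.1, 0.2], bias = [0.3])   # shaped like a Zygote NamedTuple gradient

# Both trees are traversed in lockstep; the function sees matching leaves from each.
fmap(model, grads) do x, dx
    println("leaf ", x, " has gradient ", dx)
    x
end
```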
Further, it moves towards https://github.com/FluxML/Optimisers.jl/pull/16#issuecomment-897415899 for a nicer API.
This is relevant for the DDP work too. It's one of the fixes that needed to go in to make it tick. NamedTuples are more reliably serialised than Grads, which cause all sorts of issues with reference types.
CI will currently fail since we need to update the API in a lot of other places, including the rules themselves.
I think the old interface matches https://github.com/FluxML/Functors.jl/pull/1 better, but I'm fine with putting everything into a tuple. I would still suggest that the interface should be:
update(o, st, x, (dx1, dx2, ...))
AFAICT the only missing part is functionality to convert a tree of gradients from Zygote into a tree of tuples. What are your thoughts about implementing that?
Which part are you referring to, exactly? We use the NamedTuple directly to parse the gradients. Are you referring to partials and higher-order terms here? For those, I think we can wrap dx in a tuple, but those would need to be handled in apply, right? The rules for all optimisers may not be generic enough to receive tuples as gradients.
Yes, higher-order terms in particular. The current interface on master puts the state first for update(o, state, x::T, x̄s...) and apply(o, state, x, dxs...). Those were changed to update(o, x::T, x̄, state) and apply(o, x, dx, state) in this PR. My understanding was that with this approach, x̄ and dx must be tuples for higher-order optimizers.
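For concreteness, here is a hedged sketch of what a higher-order rule might look like if `dx` must be a tuple under the state-last interface; `NewtonLike`, the update formula, and the return convention are all illustrative, not rules from this package:

```julia
struct NewtonLike
    η::Float64
end

# Under this scheme dxs carries (first-order, second-order) terms as a tuple.
function apply(o::NewtonLike, x, dxs::Tuple, state)
    dx, d2x = dxs
    Δ = o.η .* dx ./ (abs.(d2x) .+ eps())
    return Δ, state
end
```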
However, Zygote does not return gradients with each parameter pre-wrapped in a tuple. That leaves us with 2 choices using this state-last interface:
1. `fmap`-wrap the gradients before passing them to `update`/`apply` (sketched just below this list). This would require `apply` methods to use `dx[1]`, `dx[2]` etc. instead of varargs.
2. Only (un)wrap first-order gradients. This would allow most optimizer `apply` methods to stay as-is (i.e. treat `dx` as a parameter instead of a 1-tuple), but anything second-order or higher would be stuck with the aforementioned `dx[...]` pattern. It would also add another layer to the process: `gradient(...)` (unwrapped) -> wrap -> `update(...)` -> unwrap -> `apply(...)`.
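A hedged sketch of option 1, assuming Functors.jl's `fmap`; the gradient values are made up:

```julia
using Functors

grads = (x = [0.1, 0.2], y = (a = [0.3, 0.4], b = [0.5, 0.6]))

# Wrap every gradient leaf in a 1-tuple before handing the tree to update/apply.
wrapped = fmap(dx -> (dx,), grads)
# wrapped == (x = ([0.1, 0.2],), y = (a = ([0.3, 0.4],), b = ([0.5, 0.6],)))
```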
Alternatively, this PR could revert to a state-first interface. That would sidestep the need for wrapping and make full use of FluxML/Functors.jl#1's support for parallel object trees (`Tuple{FunctorType{A}, FunctorType{B}}` instead of `FunctorType{Tuple{A,B}}`).
Option 2 is what we are building towards, but with wrapping only when higher-order terms are present. Apart from that, I think we are good here.
What do we lose by sticking with the interface already on master? Why do we want to add additional steps to the design?
I have not followed Optimisers.jl closely, but in trying to solve https://github.com/FluxML/Flux.jl/issues/1826 I re-invented `fmap(f, xs...)`, and then started to wonder whether it has the right behaviour. See examples here:
https://github.com/FluxML/Flux.jl/issues/1826#issuecomment-1013783952
In particular, it won't accumulate gradients for parameters which are repeated, and it won't handle branches of the tree with no gradient. I don't know whether its behaviour on such things is what's intended, or a bug. There is about one test, which does not include such cases.
Happy to add tests, but that comment seems to be asking for something entirely different?
No. The most relevant line is:
fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))
This will happen all the time with gradients, since Zygote collapses nothings.
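For example (a hedged illustration; the model values are made up), a branch that does not participate in the loss comes back as `nothing` rather than as a zero-filled tree:

```julia
using Zygote

m = (x = [1.0, 2.0], y = (a = [3.0, 4.0], b = [5.0, 6.0]))
g = gradient(m -> sum(m.x), m)[1]
# g.x is all ones (possibly a lazy Fill), while g.y === nothing
```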
The second most relevant line is:
sh = [7,7]
sh2 = [0,7]
fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3]))
Since fmap explicitly has special handling for repeated leaves (`a === b`), it also needs to accumulate their gradients, which can independently be `===` each other without consequence.
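To spell out the accumulation being asked for (a hedged sketch using the values above): the shared array is reachable as both `x` and `z`, so its gradients from both paths should be summed before any rule is applied to it.

```julia
sh    = [7, 7]
model = (x = sh, y = [3, 4], z = sh)
grads = (x = [0, 7], y = [0, 7], z = [0, 3])

# Gradient that should reach the shared leaf `sh`:
total_for_sh = grads.x .+ grads.z   # == [0, 10]
```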
That's because of the cache in fmap, right? fmap isn't even used here, so I don't think it is affected.
But that speaks to a wider problem about lack of consistency. Applying the optimizers should correspond to just `fmap`ing the rule across the structure and gradients. Going for consistency from the outset means we will have less whack-a-mole with behavioral bugs.
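A hedged sketch of that picture, assuming the multi-argument `fmap` from FluxML/Functors.jl#1; `sgd_step` is an illustrative rule, not this package's API:

```julia
using Functors

sgd_step(x, dx) = x .- 0.1 .* dx

model = (weight = [1.0, 2.0], bias = [0.5])
grads = (weight = [0.1, 0.2], bias = [0.05])

# Applying the optimiser is "just" mapping the rule over both trees.
newmodel = fmap(sgd_step, model, grads)
# (weight = [0.99, 1.98], bias = [0.495])
```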
Yes, maybe fmap is off-topic. I have not tried hard to digest what this PR does.
But I don't see any attempt in this package to test tricky cases, and fmap is illustrative of whether they are being considered at all.
Can we use the move here to upgrade such things? Any open issue on Flux should be translated into a test here, of whatever tricky case is exposed, and a few more inspired by it. They can be marked broken for now.
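A hedged sketch of translating one such case into a (currently broken) test, using the Test stdlib; the expected accumulation is only stated in a comment since the final API for applying it is still being decided:

```julia
using Test

@testset "tied parameters" begin
    sh    = [7.0, 7.0]
    model = (x = sh, y = [3.0, 4.0], z = sh)
    grads = (x = [0.0, 7.0], y = [0.0, 7.0], z = [0.0, 3.0])
    # Expected: the shared leaf `sh` sees the accumulated gradient [0.0, 10.0].
    @test_broken false   # placeholder until tied-parameter accumulation lands
end
```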