
Handle gradients and object trees together

Open DhairyaLGandhi opened this issue 4 years ago • 12 comments

Flux has model structs, and Zygote returns NamedTuple gradients for them. With FluxML/Functors.jl#1 we add the ability to handle gradients in Functors.jl - in other words, to do a "zipped" tree walk over two equivalent functors, applying a function to their matching leaves.
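For example, a minimal sketch of such a zipped walk, assuming the multi-tree fmap from that PR (the model and grad values are stand-ins, not anything from this package):

```julia
using Functors

model = (weight = [1.0, 2.0], bias = [3.0])
grad  = (weight = [0.1, 0.2], bias = [0.3])   # the shape Zygote would return

# Walk both trees in lockstep, combining matching leaves:
fmap((x, dx) -> x .- 0.1 .* dx, model, grad)
# (weight = [0.99, 1.98], bias = [2.97])
```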

This extends that functionality to Optimisers.jl, allowing similar walks which call apply on each leaf (no pun intended).

Further, it moves towards https://github.com/FluxML/Optimisers.jl/pull/16#issuecomment-897415899 for a nicer API

DhairyaLGandhi avatar Aug 12 '21 07:08 DhairyaLGandhi

This is relevant for the DDP work too. It's one of the fixes that needed to go in to make it tick. NamedTuples are serialised more reliably than Grads, which cause all sorts of issues with reference types.

CI will currently fail since we need to update the API in a lot of other places, including the rules themselves.

DhairyaLGandhi avatar Aug 12 '21 07:08 DhairyaLGandhi

I think the old interface matches https://github.com/FluxML/Functors.jl/pull/1 better, but I'm fine with putting everything into a tuple. I would still suggest that the interface should be:

```julia
update(o, st, x, (dx1, dx2, ...))
```
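For example, assuming update returns the new state and parameters (still to be settled), the calling shape would be:

```julia
# First-order rule: a single gradient still arrives as a 1-tuple
st, x = update(o, st, x, (dx,))

# Higher-order rule: co-gradients travel together in one tuple
st, x = update(o, st, x, (dx1, dx2))
```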

darsnack avatar Aug 12 '21 13:08 darsnack

AFAICT the only missing part is functionality to convert a tree of gradients from Zygote into a tree of tuples. What are your thoughts about implementing that?
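Something like this sketch might be enough (wrap is a hypothetical helper, and I'm assuming fmap recurses into Zygote's NamedTuples):

```julia
using Functors

# Wrap each gradient leaf in a 1-tuple so rules can index dx[1], dx[2], ...
# Zygote collapses missing gradients to nothing; leave those untouched.
wrap(dx) = dx === nothing ? nothing : (dx,)

grads = (weight = [0.1, 0.2], bias = nothing)   # example Zygote-style gradient
fmap(wrap, grads)                               # (weight = ([0.1, 0.2],), bias = nothing)
```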

ToucheSir avatar Oct 28 '21 03:10 ToucheSir

Which part are you referring to exactly? We directly use the NamedTuple to parse the gradients. Are you referring to partials and higher-order terms here? For those, I think we can wrap dx in a tuple, but those would need to be handled in apply, right? The rules for all optimisers may not be generic enough to receive tuples as gradients.

DhairyaLGandhi avatar Dec 07 '21 08:12 DhairyaLGandhi

Yes, higher-order terms in particular. The current interface on master puts the state first: update(o, state, x::T, x̄s...) and apply(o, state, x, dxs...). Those were changed to update(o, x::T, x̄, state) and apply(o, x, dx, state) in this PR. My understanding was that with this approach, x̄ and dx must be tuples for higher-order optimizers.

However, Zygote does not return gradients with each parameter pre-wrapped in a tuple. That leaves us with 2 choices using this state-last interface:

  1. fmap-wrap the gradients before passing them to update/apply. This would require apply methods to use dx[1], dx[2] etc. instead of varargs.
  2. Only (un)wrap first-order gradients. This would allow most optimizer apply methods to stay as-is (i.e. treat dx as a plain parameter instead of a 1-tuple), but anything second-order or higher would be stuck with the aforementioned dx[...] pattern. It would also add another layer to the process: gradient(...) (unwrapped) -> wrap -> update(...) -> unwrap -> apply(...).

Alternatively, this PR could revert to a state-first interface. That would sidestep the need for wrapping and make full use of FluxML/Functors.jl#1's support for parallel object trees (Tuple{FunctorType{A}, FunctorType{B}} instead of FunctorType{Tuple{A,B}}).
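To make the two shapes concrete (a sketch, with stand-in values):

```julia
using Functors

model = (weight = [1.0, 2.0],)
grads = (weight = [0.1, 0.2],)

# Parallel object trees, Tuple{FunctorType{A}, FunctorType{B}}: two whole
# trees of the same shape, zipped by fmap with no wrapping needed.
fmap((x, dx) -> x .- dx, model, grads)   # (weight = [0.9, 1.8],)

# Versus a single tree whose leaves are tuples, FunctorType{Tuple{A,B}},
# which the state-last interface must build and unbuild around each call.
fmap(tuple, model, grads)                # (weight = ([1.0, 2.0], [0.1, 0.2]),)
```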

ToucheSir avatar Dec 08 '21 00:12 ToucheSir

Choice 2 is what we are building towards, but with wrapping only when higher-order terms are present. Apart from that I think we are good here.

DhairyaLGandhi avatar Jan 16 '22 15:01 DhairyaLGandhi

What do we lose by sticking with the interface already on master? Why do we want to add additional steps to the design?

darsnack avatar Jan 16 '22 16:01 darsnack

I have not followed Optimisers.jl closely, but in trying to solve https://github.com/FluxML/Flux.jl/issues/1826 I re-invented fmap(f, xs...), and then started to wonder whether it has the right behaviour. See the examples here:

https://github.com/FluxML/Flux.jl/issues/1826#issuecomment-1013783952

In particular, it won't accumulate gradients for parameters which are repeated, and it won't handle branches of the tree with no gradient. I don't know whether its behaviour on such things is intended, or a bug. There is about one test, which does not include such cases.

mcabbott avatar Jan 16 '22 16:01 mcabbott

Happy to add tests, but isn't that comment asking for something entirely different?

DhairyaLGandhi avatar Jan 16 '22 16:01 DhairyaLGandhi

No. The most relevant line is:

```julia
fmap(println, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))
```

This will happen all the time with gradients, since Zygote collapses nothings.
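One shape a fix could take, sketched with a hypothetical walk_with_grad (not a Functors.jl function): short-circuit on nothing before recursing.

```julia
using Functors

walk_with_grad(f, x, ::Nothing) = x                 # gradient-free branch: keep x
walk_with_grad(f, x::AbstractArray, dx) = f(x, dx)  # matching leaves: apply f
function walk_with_grad(f, x, dx)
    children, re = Functors.functor(x)              # destructure the functor
    re(map((c, dc) -> walk_with_grad(f, c, dc), children, dx))
end

walk_with_grad(+, (x=[1,2], y=(a=[3,4], b=[5,6])), (x=[0,2], y=nothing))
# (x = [1, 4], y = (a = [3, 4], b = [5, 6]))
```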

The second most relevant line is:

```julia
sh = [7,7]
sh2 = [0,7]
fmap(println, (x=sh, y=[3,4], z=sh), (x=sh2, y=sh2, z=[0,3]))
```

Since fmap explicitly has special handling for repeated leaves, a === b, it also needs to accumulate their gradients. The gradients themselves can independently be === each other without consequence.
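Concretely, the accumulation could look like this sketch, keyed by identity the same way fmap's cache is:

```julia
# Gradients for leaves that are === must be summed: with sh shared above,
# its total gradient should be sh2 .+ [0,3] == [0,10].
acc = IdDict{Any,Any}()
accum!(x, dx) = (acc[x] = haskey(acc, x) ? acc[x] .+ dx : dx)

sh, sh2 = [7,7], [0,7]
accum!(sh, sh2)      # first sighting of sh: [0,7]
accum!(sh, [0,3])    # second sighting: accumulated to [0,10]
```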

mcabbott avatar Jan 16 '22 16:01 mcabbott

That's because of the cache in fmap, right? fmap isn't even used here, so I don't think it is affected.

But that speaks to a wider problem about lack of consistency. Applying the optimizers should correspond to just fmapping the rule across the structure and gradients. Going for consistency from the outset means less whack-a-mole with behavioral bugs.
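i.e. roughly this correspondence (a sketch: multi-tree fmap assumed, the state tree omitted for brevity, and apply_rule standing in for the rule's eventual signature):

```julia
using Functors

# The whole of update is one zipped fmap of the rule over parameters and
# gradients, so it inherits fmap's treatment of nothing-branches and shared
# leaves, for better or worse.
update(o, model, grads) = fmap((x, dx) -> apply_rule(o, x, dx), model, grads)
```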

darsnack avatar Jan 16 '22 16:01 darsnack

Yes, maybe fmap is off-topic. I have not tried hard to digest what this PR does.

But I don't see any attempt in this package to test tricky cases, and fmap is illustrative of whether they are being considered at all.

Can we use the move here to upgrade such things? Any open issue on Flux should be translated into a test here, covering whatever tricky case it exposes, plus a few more inspired by it. They can be marked broken for now.
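For instance, translating the fmap examples above (the expected values are my assumption about what the right behaviour would be):

```julia
using Test, Functors

# Gradient-free branch: should eventually pass through instead of failing.
@test_broken fmap(+, (x=[1,2], y=(a=[3,4],)), (x=[0,2], y=nothing)) isa NamedTuple

# Shared leaves: today the cache reuses the first result instead of
# accumulating, so both fields come back as [7,14] rather than [7,17].
sh = [7,7]
out = fmap(+, (x=sh, z=sh), (x=[0,7], z=[0,3]))
@test_broken out.x == out.z == [7,17]
```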

mcabbott avatar Jan 16 '22 17:01 mcabbott