Upgrade `train!` to work with explicit parameters

Open mcabbott opened this pull request 2 years ago • 2 comments

This PR proposes to move away from implicit parameters not by simply deleting train!, but by rewriting it to use explicit mode. This gives implicit train! an easy upgrade path, and means the new explicit train! can later be changed to use something other than Zygote.

The option to use Optimisers.jl directly remains. But the style is quite different, and looking after the state yourself requires a certain amount of boilerplate. Under this PR, Flux continues to offer a tidier version, which exploits mutation to update models & state objects.
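
For comparison, a minimal sketch of the direct Optimisers.jl style, assuming Zygote for gradients (the data, model, and function names here are illustrative, not from this PR): you build the state tree yourself and must carry both it and the model forward after every update.

using Flux, Optimisers, Zygote

function fit(model, data; rule = Optimisers.Adam())
  state = Optimisers.setup(rule, model)   # build the state tree once
  for (x, y) in data
    grad = Zygote.gradient(m -> Flux.mse(m(x), y), model)[1]
    state, model = Optimisers.update(state, model, grad)  # carry both forward yourself
  end
  return model
end

data = [(rand(Float32, 3, 2), rand(Float32, 1, 2)) for _ in 1:10]  # toy batches
model = fit(Chain(Dense(3 => 7, relu), Dense(7 => 1)), data)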

The mutable state involves a new optimiser wrapper type, which is used for both explicit and implicit mode. Both modes use Optimisers.jl internally, so all the rule definitions in Flux.Optimise can be deleted. While many uses of the old train! will continue to work without modification, I think this is likely to be sufficiently breaking that it can only be in v0.14.
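
The wrapper itself is not shown here; purely as a sketch of the idea (hypothetical name and fields, not this PR's actual code), it could be a mutable struct that owns the rule and lazily holds the Optimisers.jl state tree, so that train! can fill it in and mutate it in place:

# Hypothetical sketch only: name & fields are illustrative, not this PR's implementation.
mutable struct FluxState{R}
  rule::R       # an Optimisers.jl rule, e.g. Optimisers.Adam()
  state::Any    # the Optimisers.jl state tree, filled in on first use
end

FluxState(rule) = FluxState(rule, nothing)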

Example

A simple example that runs both modes; it also works if you overload explicit_withgradient to use Diffractor instead of Zygote in explicit mode:

using Flux, Random
data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;

# This exact code works on current Flux 0.13. There, train! returns nothing:
model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt2 = Flux.Adam()
Flux.train!(Flux.params(model2), data, opt2) do x, y
  Flux.mse(model2(x), y)
end
opt2  # contains an IdDict

# This is the new "explicit" method of Train
model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt1 = Flux.Adam()
Flux.train!(model1, data, opt1) do m, x, y
  Flux.mse(m(x), y)
end |> sum
opt1  # contains state tree

# This changes the AD used:
import Diffractor
function Flux.Train.explicit_withgradient(f, args...)
  y, back = Diffractor.∂⃖¹(f, args...)  # primal value and pullback from Diffractor
  @info "used Diffractor!"
  return (; value = y, gradient = Base.tail(back(one(y))))  # drop the gradient w.r.t. f itself
end

# This is the new 3-arg train!: one step, not an iteration over data:
x1, y1 = data[1]
Flux.train!(model1, opt1) do m
  Flux.mse(m(x1), y1)
end

Checklist

  • [ ] Cover all the optimisation rules
  • [ ] More tests!
  • [x] Entry in NEWS.md
  • [ ] Many documentation changes

mcabbott, Jul 28 '22

What's the status now?

zsz00, Sep 17 '22

Nothing has moved really. But after reading the docs a bit, I think some variant of train! ought to survive the move away from implicit parameters.

A simpler take than this PR's present state might be to remove implicit Flux.params etc. completely, and to merge something like https://github.com/FluxML/Optimisers.jl/pull/106, which would, as a side effect, let train! update the optimiser state in-place. There would still be a new setup step like opt = Flux.setup(model, Adam()), but the resulting opt would not need to be explicitly passed around.
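
Roughly, that workflow might look like this (a hypothetical sketch following the suggestion above, reusing model and data from the earlier example; names and argument order are illustrative and the eventual API may differ):

opt = Flux.setup(model, Adam())        # one-time setup, builds the state tree
for epoch in 1:10
  Flux.train!(model, data, opt) do m, x, y
    Flux.mse(m(x), y)
  end                                  # opt is mutated in place, no re-binding needed
end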

mcabbott, Sep 17 '22

Closing in favour of https://github.com/FluxML/Flux.jl/pull/2082

Adding more code to deal with implicit parameters in new ways doesn't seem great.

mcabbott, Oct 16 '22