Optimisers.jl
`reset!(optimiser_state)`
Motivation and description
In my application I run 25 gradient-descent update! steps in a loop (while solving a differential equation). The momentum from one block of 25 steps must NOT carry over to the next block of 25. In other words, the behavior I am looking for is analogous to calling Flux.setup(optimiser, model) before every block. Unfortunately, Flux.setup is type-unstable (https://github.com/FluxML/Optimisers.jl/issues/162). It would be great to have a function reset!(optimiser_state) that resets the momenta. A more stringent requirement might be that
state = Flux.setup(optimiser, model)
# do some training
reset!(state)
state == Flux.setup(optimiser, model)
holds.
Possible Implementation
Below is an implementation for Adam.
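using Optimisers: Optimisers, Leaf

# Zero both Adam moment buffers in place and restore the initial beta powers.
function reset!(leaf::Leaf{A, S}) where {A <: Optimisers.Adam, S}
    leaf.state[1] .= 0
    leaf.state[2] .= 0
    leaf.state = (leaf.state[1], leaf.state[2], leaf.rule.beta)
    nothing
end

# Only handles a state tree of the form (layers = (...),) whose layers
# each have `weight` and `bias` leaves.
function reset!(state::NamedTuple{(:layers,), L}) where {L}
    for layer in state.layers
        reset!(layer.weight)
        reset!(layer.bias)
    end
    nothing
end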
One possible design is this:
reset!(tree) = foreach(reset!, tree)
reset!(ℓ::Leaf) = ℓ.state = reset!(ℓ.rule, ℓ.state)
reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) = throw(ArgumentError("reset! does not know how to handle this rule."))
Rules then need to opt in by defining a method of the 2-arg reset!, perhaps using some fill!! helper that also allows for immutable arrays?
reset!(rule::Adam, (mt, vt, βt)) = (fill!!(mt, 0), fill!!(vt, 0), rule.beta)
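To make the opt-in design concrete, here is a minimal self-contained sketch. It does not use the real Optimisers.jl types: `AbstractRule`, `Adam`, `Descent` and `Leaf` below are toy stand-ins, and `fill!!` is a hypothetical helper (its fallback for non-`Array` inputs is just one assumption about how immutable arrays could be handled).

```julia
# Toy stand-ins that mirror the names in the discussion (NOT the real Optimisers.jl types).
abstract type AbstractRule end

struct Adam <: AbstractRule
    eta::Float64
    beta::Tuple{Float64,Float64}
end

struct Descent <: AbstractRule   # stateless rule, its state is `nothing`
    eta::Float64
end

mutable struct Leaf{R<:AbstractRule,S}
    rule::R
    state::S
end

# Hypothetical helper: fill in place for ordinary Arrays,
# otherwise allocate a new array of the same shape and fill that.
fill!!(x::Array, v) = fill!(x, v)
fill!!(x::AbstractArray, v) = fill!(similar(x), v)

# Walk the state tree; each leaf delegates to the 2-arg method for its rule.
reset!(tree::Union{Tuple,NamedTuple}) = foreach(reset!, tree)
reset!(ℓ::Leaf) = (ℓ.state = reset!(ℓ.rule, ℓ.state); nothing)

# Stateless rules have nothing to reset; unknown rules fail loudly.
reset!(::AbstractRule, ::Nothing) = nothing
reset!(rule::AbstractRule, state) =
    throw(ArgumentError("reset! does not know how to handle $(typeof(rule))"))

# Adam opts in: zero both moments and restore the initial beta powers.
reset!(rule::Adam, (mt, vt, βt)) = (fill!!(mt, 0), fill!!(vt, 0), rule.beta)

# Usage: a state tree shaped like (layers = ((weight = ..., bias = ...),),)
state = (layers = ((weight = Leaf(Adam(1e-3, (0.9, 0.999)),
                                  ([0.1, 0.2], [0.3, 0.4], (0.81, 0.998))),
                    bias   = Leaf(Descent(0.1), nothing)),),)
reset!(state)
```

After the call, the Adam leaf holds zeroed moment buffers (the same arrays, mutated in place) and `rule.beta` as its third state entry, while the stateless leaf is unchanged.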
We can't easily fall back to calling init again for unknown rules, as we don't have the original parameters x here.
Falling back to zeroing the state like this might be OK for built-in rules such as Momentum, but could be wrong for user-defined rules, so we probably shouldn't:
reset!(rule::AbstractRule, state::AbstractArray) = fill!!(state, 0)
We could always make reset! take the parameter tree xs as well, but that may come at the cost of type stability.
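A rough sketch of that alternative, again with toy stand-ins rather than the real Optimisers.jl types: `init` here only plays the role of the rule's state constructor, and the sketch assumes the parameter tree has exactly the same nesting and field order as the state tree. It makes no claim about type stability in the real package.

```julia
# Toy stand-ins (hypothetical, not the real Optimisers.jl definitions).
abstract type AbstractRule end

struct Momentum <: AbstractRule
    eta::Float64
    rho::Float64
end

mutable struct Leaf{R<:AbstractRule,S}
    rule::R
    state::S
end

# Plays the role of `init`: build the fresh state for parameter array `x`.
init(::Momentum, x::AbstractArray) = zero(x)

# Walk the state tree and the parameter tree in lockstep,
# re-initialising every leaf from its matching parameter array.
reset!(tree::Union{Tuple,NamedTuple}, xs) = foreach(reset!, tree, xs)
reset!(ℓ::Leaf, x::AbstractArray) = (ℓ.state = init(ℓ.rule, x); nothing)

# Usage: `params` and `state` share the same nesting.
params = (layers = ((weight = [1.0, 2.0], bias = [0.5]),),)
state  = (layers = ((weight = Leaf(Momentum(0.01, 0.9), [0.3, -0.2]),
                     bias   = Leaf(Momentum(0.01, 0.9), [0.7])),),)
reset!(state, params)
```

With this variant no rule has to opt in, since any rule that defines `init` is covered, at the price of requiring the caller to pass the parameters.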