Design
Have been discussing with @pkofod how to design optimisers that can be used across Flux, Optim.jl and perhaps others. It seems the basic outline of a design in https://github.com/FluxML/Flux.jl/issues/637 is something that Optim can work with. We're currently looking at splitting this into:
state = init(rule, x)
dx', state' = apply(rule, x, dx, state)
x' = update(x, dx')
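For concreteness, here is a rough sketch of how the three functions could fit together. The `Momentum` rule, its fields `η` and `ρ`, and the choice of a velocity array as state are assumptions for illustration only, not part of the proposal. Since `'` is the adjoint operator in Julia, the primed names above are written as `dx2`, `state2` and `x2`.

```julia
# Hypothetical rule: gradient descent with momentum.
struct Momentum
    η::Float64  # learning rate
    ρ::Float64  # momentum coefficient
end

# The state for this rule is a velocity with the same shape as x.
init(rule::Momentum, x) = zero(x)

# Return the transformed gradient and the new state; nothing is mutated.
function apply(rule::Momentum, x, dx, state)
    v = rule.ρ .* state .+ rule.η .* dx
    return v, v
end

# Default update: subtract the transformed gradient.
update(x, dx) = x .- dx

rule = Momentum(0.1, 0.9)
x = [1.0, 2.0]
dx = [0.5, -0.5]   # gradient, as if it came from an AD tool

state = init(rule, x)
dx2, state2 = apply(rule, x, dx, state)
x2 = update(x, dx2)
```

The caller threads `state2` into the next call to `apply`, so the rule itself never needs a lookup table keyed on `x`.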
Some design goals from my side:
- It should be easy to e.g. specify that structs are optimised by optimising each field.
- It should be easy to specify how custom structs like Colors are updated (e.g. clamp the values).
- apply should support state=nothing optimisers in a generic way.
- We also need an in-place update!, but at this level we don't need to do any in-place/out-of-place detection.
- Rules should be composable (e.g. weight decay and ADAM).
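A sketch of how these goals might look in code, under some loud assumptions: `Descent`, `WeightDecay`, the `Compose` combinator and the `Gray` type are all hypothetical names made up for this example, and a `NamedTuple` stands in for a struct whose fields are optimised one by one.

```julia
# Stateless rules: their state is simply `nothing`.
struct Descent
    η::Float64
end
struct WeightDecay
    λ::Float64
end
init(::Descent, x) = nothing
init(::WeightDecay, x) = nothing
apply(r::Descent, x, dx, ::Nothing) = (r.η .* dx, nothing)
apply(r::WeightDecay, x, dx, ::Nothing) = (dx .+ r.λ .* x, nothing)

# Hypothetical combinator: run several rules in sequence, one state per rule.
struct Compose{T<:Tuple}
    rules::T
end
init(c::Compose, x) = map(r -> init(r, x), c.rules)
function apply(c::Compose, x, dx, states)
    new_states = ()
    for (r, s) in zip(c.rules, states)
        dx, s2 = apply(r, x, dx, s)
        new_states = (new_states..., s2)
    end
    return dx, new_states
end

# Out-of-place and in-place updates for arrays.
update(x::AbstractArray, dx) = x .- dx
update!(x::AbstractArray, dx) = (x .-= dx; x)

# Containers are updated field by field (NamedTuple used as a stand-in).
update(x::NamedTuple, dx::NamedTuple) = map(update, x, dx)

# Hypothetical colour-like type whose update clamps to [0, 1].
struct Gray
    val::Float64
end
update(x::Gray, dx::Real) = Gray(clamp(x.val - dx, 0.0, 1.0))

rule = Compose((WeightDecay(0.1), Descent(0.5)))
x = [1.0, -2.0]
dx = [0.2, 0.4]
state = init(rule, x)
dx2, state2 = apply(rule, x, dx, state)
x2 = update(x, dx2)
```

Here the composed rule first adds the decay term to the gradient and then scales it, and `update` dispatches on the type of `x`, so custom behaviour like clamping lives in one small method.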
The current default for update(x, dx) is to calculate x .- dx; this is convenient for ML but could be changed if it's inconvenient for other things (we'll just do the negation as part of the rule).
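To illustrate the alternative, here is a small sketch where the update adds the step and the rule does the negation itself. `update_add` and `NegDescent` are hypothetical names used only for this comparison.

```julia
# Alternative convention (an assumption for illustration): update adds the step.
update_add(x, delta) = x .+ delta

# Hypothetical stateless rule that returns the negated, scaled gradient.
struct NegDescent
    η::Float64
end
init(::NegDescent, x) = nothing
apply(r::NegDescent, x, dx, ::Nothing) = (-r.η .* dx, nothing)

x = [1.0, 2.0]
dx = [0.5, -0.5]
delta, _ = apply(NegDescent(0.1), x, dx, nothing)
x2 = update_add(x, delta)
```

Either convention gives the same iterate; the only question is whether the sign lives in `update` or in each rule.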
How different is this design from the current optimisers.jl in Flux? We also need to use optimisers in our VI projects. Maybe we can help with this package?
It's designed to be essentially similar but with explicit state, rather than using IdDicts everywhere. I haven't really figured out how to make it convenient yet, though, so help is welcome. One advantage of this design is that it's flexible enough that e.g. state can actually be a whole history of states for L-BFGS and things like that, so it should be capable of expressing VI too, I would expect.
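As a rough sketch of the "state as a history" idea: the hypothetical `HistoryDescent` rule below keeps the last `m` pairs of iterate and gradient as its state. It is not L-BFGS (there is no two-loop recursion, and the step is plain scaled gradient descent); it only shows that the same `init`/`apply` interface can carry a history. All names here are assumptions for illustration.

```julia
# Hypothetical rule whose state is a bounded history of (x, dx) pairs.
struct HistoryDescent
    η::Float64  # step size
    m::Int      # number of past pairs to keep
end

# Start with an empty history.
init(::HistoryDescent, x) = Tuple{typeof(x),typeof(x)}[]

function apply(r::HistoryDescent, x, dx, history)
    # Non-mutating: build a new history and keep only the last m entries.
    new_history = vcat(history, [(copy(x), copy(dx))])
    if length(new_history) > r.m
        new_history = new_history[end-r.m+1:end]
    end
    # A quasi-Newton rule would use the history here; this sketch only scales dx.
    return r.η .* dx, new_history
end

# Small driver loop that threads the state through each step.
function optimise(rule, x, grad, n)
    state = init(rule, x)
    for _ in 1:n
        dx, state = apply(rule, x, grad(x), state)
        x = x .- dx
    end
    return x, state
end

# Minimise f(x) = sum(x.^2) / 2, whose gradient is x itself.
xfinal, hist = optimise(HistoryDescent(0.5, 3), [4.0], x -> x, 5)
```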