Mocha.jl

Making solvers more general

Open the-moliver opened this issue 11 years ago • 7 comments

I'm working on an implementation of RMSProp for Mocha, but I quickly ran into a problem: it needs more parameters than SGD (to control things like the exponential window and the adaptive step sizes), and there is no good way to define them, because the parameters accepted by SolverParameters are hard-coded. I see a couple of ways to deal with this.

  1. Make SolverParameters an abstract type, with SGDParameters and RMSPropParameters as concrete subtypes. This seems like it would break the fewest things.

  2. Get rid of SolverParameters altogether and define solvers as below:

solver = SGD(max_iter=600*2000, regu_coef=0.0, mom_policy=MomPolicy.Linear(0.5, 0.0008, 600, 0.9), lr_policy=LRPolicy.Step(0.1, 0.998, 600), load_from=base_dir)

This would break more things, but may simplify things overall. If you can think of a better way to deal with this, I'd love to hear it.
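For comparison, option 1 might look roughly like the sketch below. It is a hypothetical Julia 1.x illustration, not Mocha code: the names SGDParameters and RMSPropParameters come from the proposal, but their fields, the default values, and the helpers describe and extra_fields are assumptions made up for the example.

```julia
# Hypothetical sketch of option 1 (Julia 1.x); not Mocha's actual API.
abstract type SolverParameters end

# Base.@kwdef generates a keyword constructor that uses the defaults below.
Base.@kwdef struct SGDParameters <: SolverParameters
    max_iter::Int = 1000
    regu_coef::Float64 = 0.0
    momentum::Float64 = 0.9
end

Base.@kwdef struct RMSPropParameters <: SolverParameters
    max_iter::Int = 1000
    regu_coef::Float64 = 0.0
    exponential_window::Float64 = 0.9   # decay of the squared-gradient average
    epsilon::Float64 = 1e-8             # guards against division by zero
end

# Code that only needs the shared fields (by convention, every subtype
# has max_iter and regu_coef) accepts any SolverParameters...
describe(p::SolverParameters) = "max_iter=$(p.max_iter), regu_coef=$(p.regu_coef)"

# ...while solver-specific code dispatches on the concrete type.
extra_fields(::SGDParameters) = (:momentum,)
extra_fields(::RMSPropParameters) = (:exponential_window, :epsilon)
```

Shared code can take any SolverParameters while each solver dispatches on its own concrete type, which is why this option would leave most existing call sites alone.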

A related issue concerns the updates in the general solver loop:

solver_state.learning_rate = get_learning_rate(solver.params.lr_policy, solver_state)
solver_state.momentum = get_momentum(solver.params.mom_policy, solver_state)

Should these be moved into the individual solvers (e.g. SGD) to keep the solver loop maximally general? That would cleanly allow new solvers to have additional parameters that vary across iterations.
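A minimal sketch of what moving these updates into the solvers could look like, assuming a hypothetical hook update_schedule! that the general loop calls once per iteration and that each solver type overrides through multiple dispatch. The fixed step decay stands in for LRPolicy.Step, and all type and field names here are illustrative.

```julia
abstract type Solver end

struct SGD <: Solver
    base_lr::Float64
    momentum::Float64
end

struct RMSProp <: Solver
    base_lr::Float64
    exponential_window::Float64
end

mutable struct SolverState
    iter::Int
    learning_rate::Float64
    momentum::Float64
end

# SGD: the same work the general loop does today
# (halving every 100 iterations stands in for a real LR policy).
function update_schedule!(solver::SGD, state::SolverState)
    state.learning_rate = solver.base_lr * 0.5^div(state.iter, 100)
    state.momentum = solver.momentum
    return state
end

# RMSProp: constant overall step size and no momentum; any per-parameter
# adaptation would live in its own internal state instead.
function update_schedule!(solver::RMSProp, state::SolverState)
    state.learning_rate = solver.base_lr
    state.momentum = 0.0
    return state
end

# The general loop no longer knows about LR or momentum policies.
function run_loop!(solver::Solver, state::SolverState, n_iter::Int)
    for _ in 1:n_iter
        state.iter += 1
        update_schedule!(solver, state)   # solver-specific, chosen by dispatch
        # ... gradient computation and parameter update would go here ...
    end
    return state
end
```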

I can work on both these things, but wanted to discuss first.

the-moliver, Jan 22 '15 22:01

@the-moliver Thanks for the suggestions. I think SolverParameters is designed to hold the general parameters that are needed by all solvers. If a solver needs extra control parameters, I suggest that it define its own, like this:

solver = RMSProp(params, exponential_window=...)

Here params is still the general SolverParameters. What do you think?
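The suggestion might be sketched as follows. This is an illustration only: the two fields shown for SolverParameters are a simplified assumption rather than Mocha's actual definition, and the epsilon field and both default values are hypothetical.

```julia
# Simplified stand-in for the general parameters shared by all solvers.
struct SolverParameters
    max_iter::Int
    regu_coef::Float64
end

# RMSProp wraps the general parameters and adds its own controls.
struct RMSProp
    params::SolverParameters
    exponential_window::Float64
    epsilon::Float64
end

# Keyword constructor with (hypothetical) defaults for the RMSProp-specific fields.
RMSProp(params::SolverParameters; exponential_window::Real = 0.9, epsilon::Real = 1e-8) =
    RMSProp(params, Float64(exponential_window), Float64(epsilon))
```

The general loop can keep reading shared settings through solver.params, while the RMSProp update reads its own fields.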

As for the second issue, the updates live in the general solver loop to encourage code reuse. However, since RMSProp has adaptive step sizes, that principle no longer seems to apply here, so I'm OK with moving them into the specific solvers. There is one minor issue, though. For SGD, for example, there are

  • SolverState, which holds the general state common to all solvers (learning rate, momentum, etc.).
  • SGDInternalState, which holds whatever is needed internally to implement SGD.

While it is OK to move the computation of some SolverState fields into specific solvers, I think the best way to do it is to let each solver define a customized function that fills in those values. As you can see at the beginning of solve(solver::Solver, net::Net), we need to compute those values before entering the loop, because the user may ask for the solver state to be dumped to a snapshot at iteration 0.
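The ordering constraint can be illustrated with a small hypothetical loop. Here fill_state! plays the role of the solver-specific function that fills in SolverState values; it runs once before the loop so that a snapshot taken at iteration 0 holds valid values rather than placeholders. Snapshots are simply recorded as (iteration, learning rate) pairs, and the 1/(1 + iter) decay is an arbitrary example.

```julia
mutable struct SolverState
    iter::Int
    learning_rate::Float64
end

struct SGD
    base_lr::Float64
    snapshot_every::Int   # 0 disables snapshots
end

# Solver-specific hook that fills in the values stored in SolverState.
function fill_state!(solver::SGD, state::SolverState)
    state.learning_rate = solver.base_lr / (1 + state.iter)
    return state
end

function solve(solver::SGD, max_iter::Int)
    state = SolverState(0, NaN)                   # NaN until the hook runs
    snapshots = Tuple{Int,Float64}[]
    fill_state!(solver, state)                    # must happen before any snapshot
    if solver.snapshot_every > 0                  # snapshot at iteration 0
        push!(snapshots, (state.iter, state.learning_rate))
    end
    while state.iter < max_iter
        state.iter += 1
        fill_state!(solver, state)
        # ... forward/backward pass and parameter update would go here ...
        if solver.snapshot_every > 0 && state.iter % solver.snapshot_every == 0
            push!(snapshots, (state.iter, state.learning_rate))
        end
    end
    return snapshots
end
```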

That being said, I would like to ask: is it possible to implement your adaptive step size as a learning rate policy? If so, I think things would be much easier.

pluskid, Jan 23 '15 02:01

I think your suggestion for how to add additional params is a good one, since it is pretty elegant and doesn't break things. I'll implement RMSProp like that.

I don't think it's possible to implement the adaptive step size as a policy, since it's adaptive for each parameter and its current values depend on the training, i.e. the learning rate for a parameter is scaled up if the sign of its gradient stays the same across iterations. I'm still getting my head around the code, so I'll spend some more time figuring out the best way to do this; doing the updates in the InternalStates seems like it would be fine, but let me know if you have any ideas. Keeping those lines in the general solver loop also matches the structure of SolverParameters, which is good: general things stay in the loop and more specific things go in each solver's functions. Also note that I was planning to implement the adaptive step sizes on top of an overall decaying step size:

overall_stepsize x (values_of_all_adaptive_step_sizes) x weight_update
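One reading of this plan for a single parameter blob is sketched below, under stated assumptions: an RMSProp-style running average of squared gradients, plus per-element step sizes that grow when the gradient sign repeats and shrink when it flips (in the spirit of Rprop). The up/down factors, the step bounds, and all names are illustrative choices, not part of the proposal.

```julia
# Sketch only: RMSProp-style normalization plus sign-based per-element step sizes.
struct AdaptiveRMSState
    mean_sq::Vector{Float64}     # running average of squared gradients
    step::Vector{Float64}        # per-element adaptive step sizes
    prev_grad::Vector{Float64}   # gradient from the previous iteration
end

AdaptiveRMSState(n::Int) = AdaptiveRMSState(zeros(n), ones(n), zeros(n))

function update!(w::Vector{Float64}, g::Vector{Float64}, st::AdaptiveRMSState,
                 overall_stepsize::Real; window=0.9, up=1.2, down=0.5,
                 min_step=1e-3, max_step=50.0, epsilon=1e-8)
    # Exponentially weighted average of g^2 (the RMSProp part)
    st.mean_sq .= window .* st.mean_sq .+ (1 - window) .* g .^ 2
    # > 0: same sign as last time, < 0: sign flipped, 0: no previous gradient
    agree = sign.(g) .* sign.(st.prev_grad)
    st.step .= clamp.(ifelse.(agree .> 0, st.step .* up,
                              ifelse.(agree .< 0, st.step .* down, st.step)),
                      min_step, max_step)
    # overall_stepsize x adaptive step sizes x normalized gradient
    w .-= overall_stepsize .* st.step .* g ./ (sqrt.(st.mean_sq) .+ epsilon)
    st.prev_grad .= g
    return w
end
```

Here overall_stepsize would be the learning rate computed by the general loop, so the decaying schedule and the per-element adaptation compose multiplicatively.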

Great work by the way!

the-moliver, Jan 23 '15 02:01

@the-moliver OK, cool! Thanks! Then I think it's better to keep the learning rate and momentum in the general solver (at least for now), since the learning rate computed there will be your overall_stepsize.

pluskid, Jan 23 '15 03:01

I have been looking at implementing an alternative solver (Adam - http://arxiv.org/abs/1412.6980).

I also found that the current implementation of the Solver API assumes too much about the specific method. In particular, the LRPolicy and Momentum policy aren't needed for Adam, but I do need other parameters.

The Adam solver code isn't ready to be merged yet (it runs, but I'm still tuning and debugging it) - but I would be glad to hear comments on the suggested API change.

I have drafted a solution here - https://github.com/pluskid/Mocha.jl/compare/master...benmoran:48e6c7e8ecdf44710adaa9d70d09af77441e9e29 - it's similar to @the-moliver's suggestion above.

It adds two abstract types, SpecificSolverParameter and SpecificSolverState, and SolverParameters and SolverState each get a member of these types. solvers.jl then gets a lot smaller, because the code handling LR and momentum moves into a separate solvers/sgd-base.jl file.
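The shape of this design might look like the following sketch. It is not the code from the linked branch: only the names SpecificSolverParameter, SpecificSolverState, SolverParameters and SolverState come from the description, and every field, concrete type, and helper (initial_state, make_state) is a hypothetical placeholder.

```julia
abstract type SpecificSolverParameter end
abstract type SpecificSolverState end

# General containers, each holding a solver-specific member.
struct SolverParameters{P<:SpecificSolverParameter}
    max_iter::Int
    specific::P
end

mutable struct SolverState{S<:SpecificSolverState}
    iter::Int
    specific::S
end

# SGD-specific pieces (LR and momentum policies reduced to plain numbers).
struct SGDParameter <: SpecificSolverParameter
    base_lr::Float64
    momentum::Float64
end
mutable struct SGDSolverState <: SpecificSolverState
    learning_rate::Float64
    momentum::Float64
end

# Adam-specific pieces: no LR/momentum policies, but its own hyperparameters.
struct AdamParameter <: SpecificSolverParameter
    stepsize::Float64
    beta1::Float64
    beta2::Float64
end
struct AdamSolverState <: SpecificSolverState end

# Each solver builds its specific state from its specific parameters.
initial_state(p::SGDParameter) = SGDSolverState(p.base_lr, p.momentum)
initial_state(::AdamParameter) = AdamSolverState()

make_state(params::SolverParameters) = SolverState(0, initial_state(params.specific))
```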

I'm wondering whether it would make sense to merge the solver-specific InternalState types with SpecificSolverState as well. I didn't do this yet because the solver state gets serialized and the InternalState could be very large, but it feels a bit messy having so many State and Parameter types floating around in this version. Perhaps the solver implementations can just provide functions that control what should be serialized.

benmoran, Sep 07 '15 08:09

@benmoran Thanks! Yes, I have been thinking about the interface of the solver. Maybe it would be much easier if we made SolverParameters more general, for example a simple dictionary of key-value pairs, perhaps with default values. What do you think?
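A dictionary-based SolverParameters could be as simple as this sketch, in which a solver's defaults sit underneath whatever the user passes in. The keys, the default values, and the make_params helper are hypothetical.

```julia
# Hypothetical: SolverParameters as a plain Symbol => value dictionary.
const SolverParameters = Dict{Symbol,Any}

# Illustrative defaults for an SGD solver.
const SGD_DEFAULTS = SolverParameters(:max_iter => 1000, :regu_coef => 0.0005, :base_lr => 0.01)

# User-specified keyword arguments override the defaults.
make_params(defaults::SolverParameters; kwargs...) = merge(defaults, SolverParameters(kwargs))
```

Because merge gives precedence to its last argument, user settings override the defaults without mutating the shared default dictionary.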

As for SolverState, I agree with you that having a lot of types makes it very messy. What do you mean when you say that the solver state could be very large for the Adam solver? Does it store some intermediate matrices? The serialized solver state is there to allow resuming training from a snapshot. I believe in letting the solver decide which parts should be serialized and which parts don't need to be, as the solver should know which parts are needed to reconstruct the training environment.

pluskid, Sep 07 '15 16:09

@pluskid, yes, the Adam solver state needs two extra blobs for each parameter blob (estimates of the first and second moments of the gradient). But it would be possible to resume training without saving them to the snapshot, restarting the estimates from scratch; you would just have to build them up again over the next iterations.
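For reference, a single Adam update from the paper can be sketched as below for one parameter blob, with m and v as the two extra blobs. The struct and function names are hypothetical; the hyperparameter defaults are the settings suggested in the paper.

```julia
# Adam state for one parameter blob.
mutable struct AdamBlobState
    m::Vector{Float64}   # first-moment (mean) estimate of the gradient
    v::Vector{Float64}   # second-moment (uncentered variance) estimate
    t::Int               # updates since the estimates were (re)started
end

AdamBlobState(n::Int) = AdamBlobState(zeros(n), zeros(n), 0)

function adam_step!(w::Vector{Float64}, g::Vector{Float64}, st::AdamBlobState;
                    stepsize=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8)
    st.t += 1
    st.m .= beta1 .* st.m .+ (1 - beta1) .* g
    st.v .= beta2 .* st.v .+ (1 - beta2) .* g .^ 2
    # Bias correction: the estimates start at zero, so early values are rescaled.
    mhat = st.m ./ (1 - beta1^st.t)
    vhat = st.v ./ (1 - beta2^st.t)
    w .-= stepsize .* mhat ./ (sqrt.(vhat) .+ epsilon)
    return w
end
```

One design note on resuming without the moment blobs: the bias correction assumes the estimates started at zero at t = 0, so if m and v are reset, the counter used in 1 - beta^t should be reset with them (as t is here) rather than taken from the global iteration number; otherwise the first steps after resuming would be scaled incorrectly (too large, for the default betas).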

I think it's a good idea to have SolverParameters just be a dictionary, with defaults provided by each solver. But since each solver implementation also needs to provide functions to save and load state snapshots, format any statistics it might add for logging, etc., I still can't think of a better way than having a separate InternalSolverState type for each one. I think we need a SolverState type like this (presumably we'll always have an iteration number and an objective value):

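abstract type InternalSolverState end

mutable struct SolverState{T<:InternalSolverState}
  iteration::Int      # current iteration number
  obj_val::Float64    # current objective value
  internal::T         # solver-specific state
end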

Then the solvers can provide functions dispatching on the types like snapshot(state::SolverState{AdamInternalSolverState}), etc.
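Dispatch on the type parameter might be used along these lines. In this hypothetical sketch the internal-state fields and the base_snapshot helper are assumptions; the SGD method saves its momentum history, while the Adam method chooses not to save its moment blobs, following the earlier point that they can be rebuilt after resuming.

```julia
abstract type InternalSolverState end

mutable struct SolverState{T<:InternalSolverState}
    iteration::Int
    obj_val::Float64
    internal::T
end

struct SGDInternalSolverState <: InternalSolverState
    momentum_history::Vector{Vector{Float64}}   # one velocity blob per parameter blob
end

struct AdamInternalSolverState <: InternalSolverState
    m::Vector{Vector{Float64}}   # first-moment estimate per parameter blob
    v::Vector{Vector{Float64}}   # second-moment estimate per parameter blob
end

# Fields that every solver state has.
base_snapshot(state::SolverState) =
    Dict{Symbol,Any}(:iteration => state.iteration, :obj_val => state.obj_val)

# SGD also saves its momentum history, so training resumes exactly.
function snapshot(state::SolverState{SGDInternalSolverState})
    d = base_snapshot(state)
    d[:momentum_history] = deepcopy(state.internal.momentum_history)
    return d
end

# Adam skips its (potentially large) moment blobs; they can be re-estimated.
snapshot(state::SolverState{AdamInternalSolverState}) = base_snapshot(state)
```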

I'll try rewriting my branch along these lines (using a Dictionary for SolverParameters and using only a single InternalSolverState type for each solver) unless you have any other suggestions.

benmoran, Sep 07 '15 20:09

@benmoran I second this design. Specifically, for SolverParameter, I think each solver could provide

  • a function to initialize a default dictionary with default parameters
  • a function to check the validity of the user specified parameters (maybe optional, but could be useful)

and then, yes, a solver state like that, plus a hook to be called when saving and loading snapshots.
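Those two per-solver functions might be sketched like this for Adam. The names default_params, validate_params and make_params, the key set, and the validation rules are hypothetical; the default values are the settings suggested in the Adam paper.

```julia
struct Adam end

# Per-solver default dictionary.
default_params(::Type{Adam}) = Dict{Symbol,Any}(
    :max_iter => 1000, :stepsize => 0.001, :beta1 => 0.9, :beta2 => 0.999, :epsilon => 1e-8)

# Per-solver validity check: reject unknown keys and out-of-range values.
function validate_params(::Type{Adam}, p::Dict{Symbol,Any})
    defaults = default_params(Adam)
    unknown = [k for k in keys(p) if !haskey(defaults, k)]
    isempty(unknown) || throw(ArgumentError("unknown Adam parameters: $unknown"))
    p[:stepsize] > 0 || throw(ArgumentError("stepsize must be positive"))
    0 <= p[:beta1] < 1 || throw(ArgumentError("beta1 must be in [0, 1)"))
    0 <= p[:beta2] < 1 || throw(ArgumentError("beta2 must be in [0, 1)"))
    return p
end

# Fill in the solver's defaults, then check the result.
make_params(S::Type, user::Dict{Symbol,Any}) = validate_params(S, merge(default_params(S), user))
```

Checking for unknown keys catches typos such as a misspelled parameter name, which a plain dictionary would otherwise accept silently.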

pluskid, Sep 07 '15 20:09