torch-autograd
Confusion with using optim with autograd
Thanks for the hard work on this, cool stuff.
I'm trying to use optim with autograd, but from looking at the examples and previous issues, I'm left a little confused.
- To what extent is optim supported?
- Should I be using autograd.optim or just optim; is there any difference?
- The one example I have found uses `local optimfn, states = autograd.optim.sgd(df, state, params3)`, which has a different signature from the sgd in the normal optim package, `sgd(opfunc, x, state)`. Is there some general rule here? (See the sketch of both call patterns below.)
- I believe optim expects flattened parameter vectors, whereas autograd supports nested tables of them. Is this different for `autograd.optim`?
If optim is supported, a paragraph and example in the readme would go a long way.
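For concreteness, here is a minimal side-by-side sketch of the two call patterns as I understand them. Treat it as illustration only: the tiny model and variable names are made up, and the argument forwarding and return values of the wrapped optimizer are my assumption based on the example above.
require 'torch'
local optim = require 'optim'
local autograd = require 'autograd'
-- Stock optim: x is a single flat Tensor, and opfunc(x) returns loss, dloss/dx.
local x = torch.randn(4)
local target = torch.randn(4)
local sgdState = { learningRate = 0.1 }
local opfunc = function(x)
   local err = x - target
   return 0.5 * err:dot(err), err        -- loss, gradient
end
optim.sgd(opfunc, x, sgdState)           -- x and sgdState are updated in place
-- torch-autograd wrapper: params can be a nested table of tensors, df returns
-- gradients first, and the optimizer is built with (fn, state, params).
local params = { layer1 = { w = torch.randn(4) } }
local state = { learningRate = 0.1 }
local df = autograd(function(params, x, y)
   local err = params.layer1.w + x - y
   return torch.sum(torch.cmul(err, err))
end)
local optimfn, states = autograd.optim.sgd(df, state, params)
-- assumption: extra arguments are forwarded to df, and the returns mirror df
-- (gradients first, then the loss); params is updated in place
local grads, loss = optimfn(torch.randn(4), torch.randn(4))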
Looking at `src/optim/init.lua`, it seems like `function(fn, state, params)` is the general signature, and it does support non-flat params?
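To make this concrete, below is a purely illustrative guess at what such a wrapper has to do. This is NOT the actual src/optim/init.lua code, just one possible way to apply a stock optim routine to a nested params table; it assumes contiguous parameter tensors and a gradient table that mirrors the params structure.
require 'torch'
local optim = require 'optim'
-- Hypothetical sketch only: wrap a stock optim routine so it accepts a
-- (possibly nested) table of parameter tensors.
local function wrapOptimizer(optimfn)               -- e.g. optim.sgd or optim.adam
   return function(fn, state, params)
      local states = {}                             -- one optimizer state per tensor
      local function shallowCopy(t)
         local c = {}
         for k, v in pairs(t) do c[k] = v end
         return c
      end
      -- walk params/grads in parallel, updating each tensor with the stock optimizer
      local function step(p, g, prefix, loss)
         for k, v in pairs(p) do
            local id = prefix .. tostring(k)
            if torch.isTensor(v) then
               states[id] = states[id] or shallowCopy(state)
               -- stock optim wants a flat x and opfunc(x) -> loss, dloss/dx;
               -- v:view(-1) shares storage with v, so v is updated in place
               optimfn(function() return loss, g[k]:view(-1) end, v:view(-1), states[id])
            else
               step(v, g[k], id .. '.', loss)
            end
         end
      end
      return function(...)
         local grads, loss = fn(params, ...)        -- fn returns gradients first
         step(params, grads, '', loss)
         return grads, loss
      end, states
   end
end
-- hypothetical usage: local sgdOnTables = wrapOptimizer(optim.sgd)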
bump, would also like to know how optim was wrapped
It does support non-flat params. And so far this has been working well for me. @alexbw need your help to confirm this :)
-- this table is passed by reference and updated every time we run the optimizer function,
-- e.g. to store the states after each iteration.
local optimState = {
   learningRate = 0.001
}
local params = {
   module1 = { W, b },
   module2 = { W, b }
}
local df = autograd(function(params, x, y) ... end)
-- for each batch do below
local x = ...
local y = ...
local feval = function(params)
   -- do something here
   local grads, loss = df(params, x, y)
   return grads, loss
end
local _,loss = autograd.optim.adam(feval, optimState, params)()
-- At this point, your `params` and `optimState` should be updated by the optim function.
-- Proceed to next batch or do the loss validation etc.
I'm not sure that's the right way of using it. It would be more efficient to call autograd.optim.adam() just once and reuse the returned function (because otherwise the optimState is deepCopy()'d over and over again into the states table). I'm surprised it works, because the optimState shouldn't be updated like that.
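Something like the following sketch of the reuse pattern, assuming the returned function forwards its extra arguments to the fn you passed in and returns the loss second, as in the snippet above. The tiny model here is made up just so the example runs.
require 'torch'
local autograd = require 'autograd'
local params = {
   module1 = { w = torch.randn(4) },
   module2 = { w = torch.randn(4) }
}
local optimState = { learningRate = 0.001 }
local df = autograd(function(params, x, y)
   local pred = torch.cmul(params.module1.w, x) + params.module2.w
   local err = pred - y
   return torch.sum(torch.cmul(err, err))
end)
-- Build the optimizer ONCE, outside the batch loop, so optimState is only
-- copied into the per-parameter states table a single time.
local optimfn, states = autograd.optim.adam(df, optimState, params)
for i = 1, 10 do                         -- stand-in for the real batch loop
   local x = torch.randn(4)
   local y = torch.randn(4)
   local _, loss = optimfn(x, y)         -- runs df and applies one Adam step to params
end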
I went over the tests and wrote a small review of the wrapper here
@ghostcow Thanks for pointing at the test code! My snippet above has worked on a test set, but I doubt it is entirely correct. Gonna fix it and try it again :)