Recurrent DQN
One central element of the Atari DQN is the use of 4 consecutive frames as input, making the state more Markovian, i.e. providing the vital movement/dynamics information. This paper http://arxiv.org/abs/1507.06527v3 discusses DRQN: the multi-frame input can be substituted with an LSTM to the same effect (but with no systematic advantage for one or the other). The DeepMind async paper also mentions using an LSTM instead of multi-frame inputs for more challenging visual domains (TORCS and Labyrinth).
I think this would fit well in this codebase; I'll try to contribute it at some point.
Yep, a switch for using a DRQN architecture would be great. For now I'd go for using `histLen` as the number of frames to use BPTT on for a single-frame DRQN. Would be good to base it on the `rnn` library, especially since it now has the optimised `SeqLSTM`.
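Something along these lines could serve as a starting point. This is only a rough sketch under my own assumptions (the `buildDRQN` name, the layer sizes and the 84x84 single-frame input are illustrative, not this repo's actual model code); it wraps a `FastLSTM` in a `Sequencer` so a table of `histLen` single-frame states can be fed in and BPTT is handled by the `rnn` library:

```lua
-- Hedged sketch of a DRQN-style network: standard DQN conv torso over a
-- single greyscale frame per time step, with an LSTM replacing frame stacking.
require 'nn'
require 'rnn' -- Element-Research rnn: provides FastLSTM and Sequencer

local function buildDRQN(nActions, hiddenSize)
  hiddenSize = hiddenSize or 512
  local net = nn.Sequential()
  -- Torso sees 1 input channel (a single 84x84 frame), not histLen channels
  net:add(nn.SpatialConvolution(1, 32, 8, 8, 4, 4))
  net:add(nn.ReLU(true))
  net:add(nn.SpatialConvolution(32, 64, 4, 4, 2, 2))
  net:add(nn.ReLU(true))
  net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1))
  net:add(nn.ReLU(true))
  net:add(nn.View(-1):setNumInputDims(3)) -- flatten 64x7x7 = 3136 features
  -- The LSTM carries the motion information that frame stacking normally provides
  net:add(nn.FastLSTM(3136, hiddenSize))
  net:add(nn.Linear(hiddenSize, nActions))
  -- Sequencer consumes a table of histLen time steps and handles BPTT
  return nn.Sequencer(net)
end

-- Usage sketch: a table of histLen single-frame states (batch x 1 x 84 x 84)
local drqn = buildDRQN(3)
local states = {}
for t = 1, 4 do states[t] = torch.rand(32, 1, 84, 84) end
local QValues = drqn:forward(states) -- table of 4 tensors, each 32 x 3
```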
This is the Caffe implementation from the paper: https://github.com/mhauskn/dqn/tree/recurrent
Although I've never looked at Caffe, it will probably help.
@Kaixhin I see you started working on this, cool. I'll have some time now, so I'll look at the multigpu and async modes.
@lake4790k Almost have something working. Disabling this line lets the DRQN train, as otherwise it crashes here, somehow propagating a batch of size 20 forward but expecting the normal batch size of 32 backwards.
I'm new to the `rnn` library, so let me know if you have any ideas. Performance is considerably slower, which will be due to having to process several time steps sequentially. This is in line with Appendix B in that paper though.
@Kaixhin Awesome! I have no experience with `rnn` either; I will need to study it to get an idea. I have two 980 Tis and will be able to run longer experiments to see if it goes anywhere.
@lake4790k I'd have to delve into the original paper/code, but it looks like they train the network every step (as opposed to every 4). This seems like it'll be a problem for BPTT. In any case, if you haven't used `rnn` before I'll focus on this.
@Kaixhin cool, I'll have my hands full with async for now, but in the meantime I'll be able to help with running longer rdqn experiments on my workstation when you think it's worth trying.
Here's the result of running `./run.sh demo -recurrent true`, so I'm reasonably confident that the DRQN is capable of learning, but I'm not testing this further for now, so I'm leaving this issue open. In any case, I still haven't solved this issue (which I mentioned above).

Pinging @JoostvDoorn since he's contributed to `rnn` and may have ideas about the minibatch problem/performance improvements/whether it's possible to save and restore state before and after training (and if that should be done since the parameters have changed slightly).
@Kaixhin I will have a look later.
@Kaixhin I'm not getting the error you mentioned when doing validation on the last batch with size 20 when running `demo`. I'm using the `master` code, which has `sequencer:remember('both')` enabled. You mention you had to disable that to not crash...? `master` runs fine for me as it is.
I think this is in the `rnn` branch. This may or may not be a bug when using `FastLSTM` with the nngraph version. Setting `nn.FastLSTM.usenngraph = false` changed the error for me, but I only got the chance to look at this for a moment.
ok, so there are two issues:

- With `nn.FastLSTM.usenngraph = true`: `nngraph/gmodule.lua:335: split(4) cannot split 32 outputs`. This is an issue in both `rnn` and `master`.
- With `nn.FastLSTM.usenngraph = false`: `Wrong size for view. Input size: 20x1x3. Output size: 32x3`. This is only in `rnn`, because @Kaixhin fixed #16 in `master` (but not in `rnn`), which returns before doing the `backward` during validation, because it is not even needed, so maybe no issue after all?
- With `nn.FastLSTM.usenngraph = true`, I get the same error as @lake4790k. This seems to be https://github.com/Element-Research/rnn/issues/172. Which is a shame, as apparently it's significantly faster with this flag enabled (see https://github.com/Element-Research/rnn/issues/182).
- Yes, so if you remove the `return` on line 374 in `master` then it fails. So I consider this a bug, albeit one that is being hidden by that return. Why is this occurring even when `states` is `20x4x1x24x24` and `QCurr` is `20x1x3`? If the error is dependent on previous batches then the learning must be incorrect. I was wrong: removing `sequencer:remember('both')` doesn't stop the crash.
@Kaixhin re: 2, agreed, this error is bad, so returning early is not a solution. I'm not sure if learning is bad with the normal batch sizes; it could just be that a batch size change isn't handled properly somewhere. I tried an isolated `FastLSTM` + `Sequencer` net, and there switching batch sizes worked fine, weirdly. I'm looking at adding LSTM to async; once I get that working I'll experiment with this further.
@lake4790k I also tried a simple `FastLSTM` + `Sequencer` net with different batch sizes and had no problem. I agree that it's likely some module is not switching its internal variables to the correct size, but finding out exactly where the problem lies is tricky. It may be that I haven't set up the recurrence correctly, but apart from this batch size issue it seems to work fine.
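For reference, here's a minimal sketch of the kind of isolated test we both tried; the sizes (batches of 32 and 20, a 3-feature input, 5 time steps) are arbitrary stand-ins for the agent's real dimensions, and on its own it switches batch sizes without error:

```lua
-- Minimal FastLSTM + Sequencer batch-size switching test (illustrative sizes)
require 'nn'
require 'rnn'

local seq = nn.Sequencer(nn.Sequential()
  :add(nn.FastLSTM(3, 4))
  :add(nn.Linear(4, 2)))
seq:remember('both')

-- Sequence of 5 steps with a batch of 32
local bigBatch = {}
for t = 1, 5 do bigBatch[t] = torch.rand(32, 3) end
local out = seq:forward(bigBatch)
seq:backward(bigBatch, out) -- output reused as a dummy gradOutput, just to exercise BPTT

-- Clear the stored 32-wide hidden state, then switch to a batch of 20
seq:forget()
local smallBatch = {}
for t = 1, 5 do smallBatch[t] = torch.rand(20, 3) end
out = seq:forward(smallBatch)
seq:backward(smallBatch, out)
print('switched batch sizes without error')
```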
@Kaixhin I need to refresh `async` from `master` for the recurrent work. Should I do a merge or a rebase (I'm leaning towards a merge)? Does it even matter when merging back from `async` to `master` eventually?
@lake4790k I'd go with a merge since it preserves history correctly. It's better to make sure all the changes in `master` are integrated sooner rather than later.
Done the merge and added `recurrent` support for 1-step Q in `async`. This is 7 minutes of training, and it seems to work well:
The agent sees only the latest frame per step and backpropagates by unrolling 5 steps on every step; weights are updated every 5 steps (or at terminal states). No `Sequencer` is needed in this algorithm. I used `sharedRmsProp` and kept the `ReLU` after the `FastLSTM` to have a setup comparable to my usual `async` testing.
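For anyone following along, here's a rough, self-contained sketch of the unrolling described above (not the actual async agent code): the tiny network and the dummy gradients are placeholders, and the real agent builds proper Q-learning targets and hands the gradients to `sharedRmsProp`, but the forward/backward ordering is the relevant part:

```lua
-- Hedged sketch: step-by-step forward, then BPTT over the last 5 steps,
-- without a Sequencer. Recursor stores per-step state so backward can be
-- called in reverse order of forward.
require 'nn'
require 'rnn'

local net = nn.Sequential():add(nn.FastLSTM(3, 8)):add(nn.Linear(8, 2))
local stepNet = nn.Recursor(net)

local nSteps, learningRate = 5, 1e-2
local inputs, gradOutputs = {}, {}

for step = 1, nSteps do
  inputs[step] = torch.rand(3)                  -- one "frame" per step, no frame stacking
  local QValues = stepNet:forward(inputs[step]) -- hidden state carried across steps
  -- Placeholder gradient: the real agent derives this from the Q-learning target
  gradOutputs[step] = QValues:clone():fill(0.1)
end

-- Backpropagate through time, most recent step first
stepNet:zeroGradParameters()
for step = nSteps, 1, -1 do
  stepNet:backward(inputs[step], gradOutputs[step])
end
stepNet:updateParameters(learningRate)
stepNet:forget() -- reset stored steps before the next 5-step unroll
```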
Pretty cool that it works. I'll see if it performs similarly with a flickering Catch, as they did in the paper with flickering Pong. Also, in the async paper they added a half-size LSTM layer after the linear layer instead of replacing it; I'll try that as well (although the DRQN paper says replacing is best).
I'll add support for the n-step methods as well. There it's a bit trickier to get right, as steps are taken forwards and backwards to calculate the n-step returns, so I'll have to take care that the forward/backward passes are correct for the LSTM too.
I also tried replacing `FastLSTM` with `GRU` with everything else the same; interestingly, that did not converge even after running it for longer.
@lake4790k Do you have the flickering catch version somewhere?
@JoostvDoorn I haven't got around to it since, but it probably only takes a few lines to add to rlenvs.
@JoostvDoorn I can add that to `rlenvs.Catch` if you want? You may also be interested in the `obscured` option I set up, which blanks a strip of screen at the bottom so that the agent has to infer the motion of the ball properly. It's quick enough to test by adding `opt.obscured = true` in `Setup.lua`.
@JoostvDoorn Done. Just get the latest version of `rlenvs` and this repo. `-flickering` is the probability (between 0 and 1) of the screen blanking out.
@Kaixhin Great thanks.
Have you tried storing the state instead of calling `forget` at every time step? I am doing this now; it takes longer to train, but it will probably converge. I agree this has to do with the changing state distribution, but we cannot really let the agent explore without considering the history if we want to take full advantage of the LSTM.
@JoostvDoorn I thought that this line would actually set `remember` for all internal modules, but I'm not certain? If that is not the case then yes, I agree it should be set on the LSTM units themselves.
In summary, in `Agent:observe`, the only place that `forget` is called is at a terminal state. Of course, when learning it should call `forget` before passing the minibatch through, and after learning as well. This means that `memSampleFreq` is the maximum amount of history the LSTMs keep during training, but they receive the entire history during validation/evaluation.
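To illustrate the intended pattern (a sketch only; the tiny network and sizes are stand-ins for the real agent): keep hidden state with `remember('both')` while acting, and wrap each minibatch learning pass in `forget` calls:

```lua
require 'nn'
require 'rnn'

-- Stand-in for the recurrent policy network (arbitrary sizes)
local policyNet = nn.Sequencer(nn.Sequential()
  :add(nn.FastLSTM(3, 4))
  :add(nn.Linear(4, 2)))
policyNet:remember('both') -- keep hidden state between single-step calls while acting

-- Acting: one observation per call; hidden state accumulates until forget()
for step = 1, 4 do
  policyNet:forward({torch.rand(3)})
end

-- Learning pass (every memSampleFreq steps):
policyNet:forget() -- clear the online hidden state before the minibatch
local histLen, batchSize = 4, 32
local sequences = {}
for t = 1, histLen do sequences[t] = torch.rand(batchSize, 3) end
local QBatch = policyNet:forward(sequences) -- histLen-long sequences sampled from memory
-- ... compute Q-learning targets and call policyNet:backward here ...
policyNet:forget() -- clear again so acting does not resume from the minibatch's state
```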
@Kaixhin Yes that line is enough, I will change that in my pull request.
I missed `memSampleFreq`, so I assumed it was calling `forget` every time. I guess `memSampleFreq >= histLen` is a good thing here, so that training and updating have a similar distribution. Do note though that the 5th action will be updated based on the 2nd, 3rd, 4th, and 5th states in the Q-learning update, while the policy followed will only be based on the 5th state, right?
@JoostvDoorn Yep, `memSampleFreq >= histLen` would be sensible. Sorry, I'm not sure I understand your last question though. During learning updates for recurrent networks, `histLen` is used to determine the sequence length of states fed in (no concatenating frames in time as with a normal DQN). During training, the hidden state will go back until the last time `forget` was called (and `forget` is called every `memSampleFreq` steps).
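In other words, the two input formats differ like this (using the 24x24 Catch screen, `histLen` of 4 and a batch of 32 seen earlier in this thread as assumed sizes):

```lua
require 'torch'

local batchSize, histLen, H, W = 32, 4, 24, 24

-- Standard DQN: histLen frames concatenated in time into one tensor per sample
local dqnInput = torch.rand(batchSize, histLen, 1, H, W)

-- DRQN: a table of histLen single-frame tensors, one per time step; a
-- Sequencer-wrapped network consumes these step by step and BPTT runs over
-- the histLen steps (back to the last forget()).
local drqnInput = {}
for t = 1, histLen do
  drqnInput[t] = torch.rand(batchSize, 1, H, W)
end
```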
I guess like this: `forget` is called at the first time step, so the LSTM will not have accumulated any information at this point; once here it will start accumulating state information (note however that on `torch.uniform() < epsilon` we don't accumulate info, which is a bug). Then, after calling `Agent:learn`, we call `forget` again. Then once the episode continues and reaches the point here, the state information is the same as at the start of the episode; depending on the environment, this is a problem.
Thanks for spotting the bug. @lake4790k please check 626712b to make sure async agents are accounted for as well.
@JoostvDoorn If I understand correctly, the issue is that the agent can't retain information during training because `observe` is interspersed with `forget` calls during `learn`? That's what I was wondering about above. My reasoning comes from the `rnn` docs. Also, it would be prohibitive to keep old states from before `learn` and pass them all through the network before starting again.
@Kaixhin yes this is needed for async, just created #47 to do it a bit differently.