stable-baselines3-contrib

a question regarding RecurrentPPO

Open zhengzl18 opened this issue 2 years ago • 2 comments

Hi, I'm using RecurrentPPO to train a RecurrentActorCriticPolicy.

I noticed that when collecting rollouts, the LSTM hidden states at each time step are also stored in the rollout buffer:

rollout_buffer.add(
    self._last_obs,
    actions,
    rewards,
    self._last_episode_starts,
    values,
    log_probs,
    lstm_states=self._last_lstm_states,
)

Then, during training, the values and the log_prob of the actions under the current policy are computed directly by passing rollout_data.lstm_states as inputs:

values, log_prob, entropy = self.policy.evaluate_actions(
    rollout_data.observations,
    actions,
    rollout_data.lstm_states,
    rollout_data.episode_starts,
)

Now that the lstm_states come directly from the buffer, rather than being recomputed from the start, doesn't that mean that the backpropagation through time procedure merely goes back one step? More precisely, the requires_grad property of rollout_data.lstm_states is False; is that reasonable?

I'm not sure if there's anything I'm missing. Hoping for some replies.

Thanks.

zhengzl18 avatar Sep 20 '22 12:09 zhengzl18

Now that the lstm_states come directly from the buffer, rather than being recomputed from the start, doesn't that mean that the backpropagation through time procedure merely goes back one step?

those lstm states are only here to initialize the hidden states at the beginning of each sampled sequence: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/7993b75781d7f43262c80c023cd83cfe975afe3a/sb3_contrib/common/recurrent/buffers.py#L214-L220

the backprop through time happens for the rest of the sequence; see the buffer to know the shape of the sampled obs: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/7993b75781d7f43262c80c023cd83cfe975afe3a/sb3_contrib/common/recurrent/buffers.py#L230-L231
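To illustrate the idea, here is a minimal PyTorch sketch (made-up shapes and names, not the actual sb3-contrib code): the buffered state only seeds the first step of the unroll, and the LSTM then runs over the whole sampled sequence, so gradients flow back through every step of that sequence even though the initial state itself has requires_grad=False:

import torch
import torch.nn as nn

seq_len, obs_dim, hidden_dim = 128, 4, 64
lstm = nn.LSTM(obs_dim, hidden_dim)

obs_seq = torch.randn(seq_len, 1, obs_dim)   # one sampled sequence of observations
h0 = torch.zeros(1, 1, hidden_dim)           # initial states taken from the buffer
c0 = torch.zeros(1, 1, hidden_dim)           # requires_grad=False, used only as a starting point

# Unroll over the full sequence: every step is part of the graph,
# so backprop through time covers the whole sequence, not just one step.
features, _ = lstm(obs_seq, (h0, c0))
loss = features.pow(2).mean()
loss.backward()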

araffin avatar Sep 20 '22 13:09 araffin

Thanks for your reply!

I see. But I'm still a little confused, because, from my perspective, the sampled obs should have shape (batch_size, history_length, obs_dim), where history_length is a hyperparameter I can tune, so that the sampled obs contain batch_size sequences, each of length history_length. That's why, when I saw that the sampled obs had shape (128, obs_dim), I assumed it was single-frame data, and hence came to the original question.

But here batch_size equals n_seq * max_length, and n_seq just comes from len(seq_start_indices) (which is always 1 in my case). It seems there isn't much I can do to choose the data shape.

I wonder why the sampling mechanism is designed like this?

zhengzl18 avatar Sep 21 '22 07:09 zhengzl18

But I'm still a little confused, because, from my perspective, the sampled obs should have shape (batch_size, history_length, obs_dim),

Actually no, the main reason is that you want to keep the mini-batch size constant (otherwise you would need to adjust the learning rate, for instance). So, internally, it does something simple as a high-level idea but complex in practice: it samples batch_size ordered observations, which means we need to specify where each sequence starts/stops. If the sequence (here a sequence is one episode) is long enough, then n_seq=1 and all sampled observations come from the same sequence.
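To make that concrete, here is a minimal sketch (hypothetical NumPy arrays and names, not the buffer's actual code) of how sequence boundaries can be recovered from the episode-start flags inside an ordered mini batch of fixed size:

import numpy as np

batch_size = 128
episode_starts = np.zeros(batch_size)     # flags for the ordered, sampled transitions
episode_starts[0] = 1.0                   # one long episode -> only one sequence start

seq_start_indices = np.where(episode_starts == 1.0)[0]
n_seq = len(seq_start_indices)                                # -> 1 here
seq_lengths = np.diff(np.append(seq_start_indices, batch_size))
max_length = int(seq_lengths.max())                           # -> 128
# After padding, the data can be viewed as (n_seq, max_length, ...),
# which is why the flattened batch ends up with n_seq * max_length rows.
print(n_seq, max_length)                                      # 1 128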

araffin avatar Sep 26 '22 12:09 araffin