stable-baselines3-contrib
a question regarding RecurrentPPO
Hi, I'm using `RecurrentPPO` to train a `RecurrentActorCriticPolicy`.
I noticed that when collecting rollout data, the LSTM hidden states at each time step are also stored in the rollout buffer:

```python
rollout_buffer.add(
    self._last_obs,
    actions,
    rewards,
    self._last_episode_starts,
    values,
    log_probs,
    lstm_states=self._last_lstm_states,
)
```
Then, during training, the `values` and the `log_prob` of the actions under the current policy are computed directly by passing `rollout_data.lstm_states` as input:

```python
values, log_prob, entropy = self.policy.evaluate_actions(
    rollout_data.observations,
    actions,
    rollout_data.lstm_states,
    rollout_data.episode_starts,
)
```
Now that the `lstm_states` come directly from the buffer, rather than being computed from the start, doesn't that mean that the backpropagation-through-time procedure only goes back one step? More precisely, the `requires_grad` attribute of `rollout_data.lstm_states` is `False`; is that reasonable?
I'm not sure if there's anything I'm missing. Hoping for some replies.
Thanks.
> Now that the `lstm_states` come directly from the buffer, rather than being computed from the start, doesn't that mean that the backpropagation-through-time procedure only goes back one step?

Those LSTM states are only there to initialize the hidden states at the beginning of each sampled sequence: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/7993b75781d7f43262c80c023cd83cfe975afe3a/sb3_contrib/common/recurrent/buffers.py#L214-L220

The backprop through time happens for the rest of the sequence; see the buffer code to check the shape of the sampled obs: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/7993b75781d7f43262c80c023cd83cfe975afe3a/sb3_contrib/common/recurrent/buffers.py#L230-L231
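To make that concrete, here is a minimal, hypothetical sketch in plain PyTorch (not SB3's actual code): the hidden state loaded from the buffer is a constant tensor with `requires_grad=False` that only seeds the LSTM at the start of the sampled sequence, while gradients still flow through every step of the sequence itself.

```python
import torch
import torch.nn as nn

seq_len, obs_dim, hidden_dim = 16, 4, 8
lstm = nn.LSTM(obs_dim, hidden_dim)

# Hidden state as it would come out of the rollout buffer:
# a constant tensor (requires_grad=False) used only to seed the sequence.
h0 = torch.zeros(1, 1, hidden_dim)
c0 = torch.zeros(1, 1, hidden_dim)

# One sampled sequence of observations, shape (seq_len, batch=1, obs_dim).
obs_seq = torch.randn(seq_len, 1, obs_dim)

# Forward pass over the whole sequence, initialized from the stored state.
out, _ = lstm(obs_seq, (h0, c0))

# A loss over the sequence backpropagates through all seq_len LSTM steps
# (full BPTT over the sampled sequence), even though h0/c0 carry no grad.
loss = out.pow(2).mean()
loss.backward()

print(lstm.weight_ih_l0.grad.abs().sum() > 0)  # tensor(True): grads flowed through time
```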
Thanks for your reply!
I see. But I'm still a little confused, because from my perspective the sampled obs should have shape `(batch_size, history_length, obs_dim)`, where `history_length` is a hyperparameter I can tune, so that the sampled obs contain `batch_size` sequences, each of length `history_length`. That's why, when I saw the sampled obs had shape `(128, obs_dim)`, I assumed it was single-frame data, which led to the original question.
But here `batch_size` equals `n_seq * max_length`, and `n_seq` just comes from `len(seq_start_indices)` (which is always 1 in my case). It seems there isn't much I can do to control the data shape.
I wonder why the sampling mechanism is designed like this?
> But I'm still a little confused, because from my perspective the sampled obs should have shape `(batch_size, history_length, obs_dim)`

Actually no. The main reason is that you want to keep the mini-batch size constant (otherwise you would need to adjust the learning rate, for instance).
So, internally, it does something that is simple as a high-level idea but complex in practice: it samples `batch_size` ordered observations, which means we need to specify where each sequence starts/stops.
If the sequence (here a sequence is one episode) is long enough, then `n_seq=1` and all sampled observations come from the same sequence.
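As a rough illustration of that design (a hypothetical sketch, not the actual buffer code; the names `episode_starts`, `seq_start_indices`, `n_seq`, and `max_length` simply mirror the ones discussed above): the sampler draws a flat, time-ordered block of `batch_size` transitions and recovers the sequence boundaries from the episode-start flags, so the number of real transitions per mini-batch stays fixed while `n_seq` depends on where episodes happen to begin.

```python
import numpy as np

batch_size = 128

# Illustrative flat, time-ordered batch of transitions (1.0 marks the first
# step of an episode within the sampled block).
episode_starts = np.zeros(batch_size)
episode_starts[[50, 90]] = 1.0

# The first sampled transition always opens a sequence (its hidden state is
# restored from the buffer), and every episode start opens a new one.
seq_start = episode_starts.astype(bool)
seq_start[0] = True
seq_start_indices = np.flatnonzero(seq_start)

n_seq = len(seq_start_indices)  # 3 here; 1 if one episode spans the whole batch
print(n_seq, seq_start_indices)

# The sequences are then padded to the length of the longest one, so the data
# fed to the LSTM can be viewed as (n_seq, max_length, obs_dim); the padded
# batch holds n_seq * max_length entries, while the number of real transitions
# per mini-batch stays fixed at batch_size.
```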