stable-baselines3-contrib
RecurrentActorCriticPolicy Behaviour Not Clear
📚 Documentation
I am trying to understand how the RecurrentActorCriticPolicy works. Coming from an NLP background, I am used to having tensors of shape (batch_size, seq_len, feature_dim) as input to the LSTM (plus optional initial hidden states). From what I am seeing, however, the LSTM as implemented basically only allows feeding sequences of length 1: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L198
In fact, by zipping features_sequence (with shape [seq_len, n_envs, feature_dims]) and episode_starts (with shape [n_envs, -1]), in the case of 1 environment we only allow seq_len to be 1.
Is this intended, and am I reading this correctly? Is the logic that, since we keep propagating the hidden state, sequences of length 1 are sufficient?
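
For concreteness, here is a minimal sketch of what I mean by the zipping (the sizes are made up, and I am assuming both tensors are iterated over their leading axis):

```python
import torch as th

# Made-up sizes: seq_len = 4, n_envs = 2, feature_dim = 8
features_sequence = th.randn(4, 2, 8)  # (seq_len, n_envs, feature_dim)
episode_starts = th.zeros(4, 2)        # assuming the time axis also leads here

# zip iterates both tensors over their leading dimension, so each
# step yields one timestep for all environments at once
for features, episode_start in zip(features_sequence, episode_starts):
    print(features.shape, episode_start.shape)
    # torch.Size([2, 8]) torch.Size([2])
```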
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
> tensors of the shape (batch_size, seq_len, feature_dim) as input to the LSTM

That's correct.
I think you missed https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L189-L194: here we pass the full sequence as input in a single call.
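
For reference, a minimal runnable sketch of that fast path (the standalone `nn.LSTM` and all sizes are made up for illustration; the real code also flattens the output afterwards):

```python
import torch as th
import torch.nn as nn

# Made-up sizes: 5 timesteps, 2 sequences, 8 features, 16 hidden units
seq_len, n_seq, feature_dim, hidden_dim = 5, 2, 8, 16
lstm = nn.LSTM(feature_dim, hidden_dim)
features_sequence = th.randn(seq_len, n_seq, feature_dim)
episode_starts = th.zeros(seq_len, n_seq)  # no episode boundary inside the batch
lstm_states = (th.zeros(1, n_seq, hidden_dim), th.zeros(1, n_seq, hidden_dim))

# Fast path: no reset is needed mid-sequence, so the whole
# (seq_len, n_seq, feature_dim) tensor goes through the LSTM in one call
if th.all(episode_starts == 0.0):
    lstm_output, lstm_states = lstm(features_sequence, lstm_states)
    print(lstm_output.shape)  # torch.Size([5, 2, 16])
```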
And for the rest, we unroll the sequence manually, because we need to reset the state of the LSTM when a new episode starts: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L202-L204
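
A sketch of that unrolling, with the same made-up sizes as above (the multiplicative mask that zeroes the states at episode boundaries mirrors the linked lines):

```python
import torch as th
import torch.nn as nn

# Made-up sizes, as above
seq_len, n_seq, feature_dim, hidden_dim = 5, 2, 8, 16
lstm = nn.LSTM(feature_dim, hidden_dim)
features_sequence = th.randn(seq_len, n_seq, feature_dim)
episode_starts = th.zeros(seq_len, n_seq)
episode_starts[3, 0] = 1.0  # hypothetical: a new episode begins at step 3 in sequence 0
lstm_states = (th.zeros(1, n_seq, hidden_dim), th.zeros(1, n_seq, hidden_dim))

outputs = []
# Unroll one timestep at a time so the hidden/cell states can be
# zeroed exactly where an episode boundary occurs
for features, episode_start in zip(features_sequence, episode_starts):
    mask = (1.0 - episode_start).view(1, -1, 1)  # 0.0 where a new episode starts
    hidden, lstm_states = lstm(
        features.unsqueeze(dim=0),  # a length-1 input sequence: (1, n_seq, feature_dim)
        (mask * lstm_states[0], mask * lstm_states[1]),
    )
    outputs.append(hidden)

lstm_output = th.cat(outputs)  # (seq_len, n_seq, hidden_dim)
```

So each individual LSTM call does see a length-1 input, but because the states are carried between iterations, the unrolled loop is equivalent to processing the full sequence, apart from the resets at episode boundaries.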