
RecurrentActorCriticPolicy Behaviour Not Clear

Open pasinit opened this issue 9 months ago • 1 comment

📚 Documentation

I am trying to understand how the RecurrentActorCriticPolicy works. Coming from an NLP background, I am used to having tensors of shape (batch_size, seq_len, feature_dim) as input to the LSTM (plus optional initial hidden states). From what I am seeing, however, the implemented LSTM basically only allows feeding sequences of length 1: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L198

In fact, by zipping features_sequence (with shape [seq_len, n_envs, feature_dims]) and episode_starts (with shape [n_envs, -1]), in the case of a single environment we only allow seq_len to be 1.

Is this intended, and am I reading this correctly? Is the logic that, since we keep propagating the state, sequences of length 1 are still sufficient?
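To make the shape question concrete, here is a toy example of what zipping over those two tensors yields (a sketch only, not the library code; it assumes both tensors are laid out time-major, as in the linked file, with seq_len = 1 matching one step per call during rollout collection):

```python
import torch as th

seq_len, n_envs, feat_dim = 1, 4, 8
features_sequence = th.randn(seq_len, n_envs, feat_dim)
episode_starts = th.zeros(seq_len, n_envs)

# zip iterates over the leading dimension of each tensor, so both must
# agree on it; each iteration yields one timestep across all envs.
for features, episode_start in zip(features_sequence, episode_starts):
    print(features.shape, episode_start.shape)  # torch.Size([4, 8]) torch.Size([4])
```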

Checklist

  • [X] I have checked that there is no similar issue in the repo
  • [X] I have read the documentation

pasinit avatar May 09 '24 10:05 pasinit

> tensors of the shape (batch_size, seq_len, feature_dim) as input to the LSTM

That's correct.

I think you missed https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L189-L194, where we pass a full sequence as input.
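For readers following along, a minimal runnable sketch of that fast path (function name and shapes reconstructed for illustration, not copied verbatim from the repo): when no episode boundary falls inside the batch, the whole sequence goes through the LSTM in a single call.

```python
import torch as th
from torch import nn

def lstm_fast_path(features_sequence, episode_starts, lstm_states, lstm):
    # features_sequence: (seq_len, n_seq, input_size)
    # episode_starts:    (seq_len, n_seq), 1.0 marks the start of a new episode
    # lstm_states:       tuple of two (1, n_seq, hidden_size) tensors
    if th.all(episode_starts == 0.0):
        # No reset needed anywhere: feed the full sequence in one call
        return lstm(features_sequence, lstm_states)
    raise NotImplementedError("episode boundary inside the batch: unroll manually")

lstm = nn.LSTM(input_size=8, hidden_size=16)
seq_len, n_seq = 5, 2
states = (th.zeros(1, n_seq, 16), th.zeros(1, n_seq, 16))
out, states = lstm_fast_path(th.randn(seq_len, n_seq, 8), th.zeros(seq_len, n_seq), states, lstm)
print(out.shape)  # torch.Size([5, 2, 16])
```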

For the rest, we unroll the sequence manually because we need to reset the LSTM state when a new episode starts: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/25b43266e08ebe258061ac69688d94144799de75/sb3_contrib/common/recurrent/policies.py#L202-L204
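And a runnable sketch of that manual unroll (again reconstructed for illustration; the real `_process_sequence` also handles the batch/sequence reshaping): the hidden and cell states are multiplied by `(1 - episode_start)`, which zeroes them exactly where a new episode begins.

```python
import torch as th
from torch import nn

def lstm_unroll_with_resets(features_sequence, episode_starts, lstm_states, lstm):
    # features_sequence: (seq_len, n_seq, input_size)
    # episode_starts:    (seq_len, n_seq), 1.0 marks the start of a new episode
    n_seq = lstm_states[0].shape[1]
    outputs = []
    for features, episode_start in zip(features_sequence, episode_starts):
        # Zero the hidden/cell state for every sequence that just restarted
        mask = (1.0 - episode_start).view(1, n_seq, 1)
        lstm_states = (mask * lstm_states[0], mask * lstm_states[1])
        # One LSTM step: unsqueeze to a length-1 sequence
        hidden, lstm_states = lstm(features.unsqueeze(dim=0), lstm_states)
        outputs.append(hidden)
    return th.cat(outputs), lstm_states

lstm = nn.LSTM(input_size=8, hidden_size=16)
seq_len, n_seq = 5, 2
starts = th.zeros(seq_len, n_seq)
starts[2, 0] = 1.0  # sequence 0 begins a new episode at step 2
states = (th.zeros(1, n_seq, 16), th.zeros(1, n_seq, 16))
out, states = lstm_unroll_with_resets(th.randn(seq_len, n_seq, 8), starts, states, lstm)
print(out.shape)  # torch.Size([5, 2, 16])
```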

araffin avatar May 10 '24 13:05 araffin