stable-baselines3
[Feature Request] Store next observations and dones in RolloutBuffer
🚀 Feature
Add next_observations and dones fields to the RolloutBuffer and the DictRolloutBuffer classes, similar to how it is done in the ReplayBuffer class.
Motivation
Currently, on-policy algorithms don't store next observations or dones in their buffer in the collect_rollouts method, because none of the algorithms in stable-baselines3 need these fields. However, these fields must be stored in the buffer to implement the original variant of the AIRL algorithm in the imitation library.
Pitch
No response
Alternatives
No response
Additional context
No response
Checklist
- [X] I have checked that there is no similar issue in the repo
> Add next_observations and dones fields to the RolloutBuffer and the DictRolloutBuffer classes, similar to how it is done in the ReplayBuffer class.

dones are already stored in episode_starts (shifted by one step), and next_observations can be retrieved as observations[i+1] (except for terminal observations).
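The reply above can be sketched with NumPy. This is a minimal sketch, assuming the arrays keep the unflattened `(n_steps, n_envs, ...)` layout the rollout buffer uses while collecting, and that the observation and done flags *after* the final step are available separately. In SB3 those would presumably be the algorithm's `_last_obs` and `_last_episode_starts` once `collect_rollouts` returns; treat that mapping as an assumption to check. The helper `derive_next_obs_and_dones` is hypothetical, not an SB3 API.

```python
import numpy as np


def derive_next_obs_and_dones(observations, episode_starts, last_obs, last_episode_starts):
    """Hypothetical helper: rebuild (next_obs, dones) from rollout-style arrays.

    Assumed shapes:
      observations:        (n_steps, n_envs, *obs_shape)
      episode_starts:      (n_steps, n_envs)   1.0 where obs[t] starts an episode
      last_obs:            (n_envs, *obs_shape) observation after the final step
      last_episode_starts: (n_envs,)            done flags of the final step
    """
    # next_obs[t] = observations[t + 1]; the final step uses last_obs.
    next_obs = np.concatenate([observations[1:], last_obs[None]], axis=0)
    # episode_starts[t + 1] is set exactly when step t ended an episode,
    # so the done flags are episode_starts shifted back by one step.
    dones = np.concatenate([episode_starts[1:], last_episode_starts[None]], axis=0)
    # Caveat: where dones is True, next_obs holds the reset observation of the
    # following episode, not the true terminal observation.
    return next_obs, dones.astype(bool)


# Usage: 3 steps, 1 env; step 1 ends an episode (episode_starts[2] == 1).
obs = np.arange(3, dtype=float).reshape(3, 1, 1)
starts = np.array([[1.0], [0.0], [1.0]])
next_obs, dones = derive_next_obs_and_dones(obs, starts, np.array([[3.0]]), np.array([0.0]))
print(next_obs[:, 0, 0], dones[:, 0])
```

This is why terminal observations are the exception: recovering them needs `infos[i]["terminal_observation"]` at collection time, since the buffer only keeps the post-reset observation.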
> Alternatives

Why not implement a custom buffer for your use case? (You can fill it using a callback or a custom SB3 version.)
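A custom buffer along those lines might look like the following minimal sketch. `TransitionBuffer` is a hypothetical class, not part of SB3; it handles array (Box-like) observations only, and it assumes a vectorized environment that auto-resets and puts the real final observation in `infos[i]["terminal_observation"]`, as SB3's `VecEnv` wrappers do.

```python
import numpy as np


class TransitionBuffer:
    """Hypothetical side buffer storing next observations and dones for one rollout."""

    def __init__(self, n_steps, n_envs, obs_shape):
        self.n_steps, self.n_envs = n_steps, n_envs
        self.next_observations = np.zeros((n_steps, n_envs, *obs_shape), dtype=np.float32)
        self.dones = np.zeros((n_steps, n_envs), dtype=bool)
        self.pos = 0

    def reset(self):
        # Start writing from the beginning of a new rollout.
        self.pos = 0

    def full(self):
        return self.pos == self.n_steps

    def add(self, new_obs, dones, infos):
        """Store one vectorized step (new_obs: (n_envs, *obs_shape), dones: (n_envs,))."""
        if self.full():
            raise IndexError("TransitionBuffer is full; call reset() first")
        next_obs = np.array(new_obs, dtype=np.float32, copy=True)
        for env_idx, done in enumerate(dones):
            # After an auto-reset, new_obs holds the first observation of the next
            # episode; swap in the true final observation stored in infos.
            terminal = infos[env_idx].get("terminal_observation")
            if done and terminal is not None:
                next_obs[env_idx] = terminal
        self.next_observations[self.pos] = next_obs
        self.dones[self.pos] = dones
        self.pos += 1
```

To fill it from a callback, one option is a `BaseCallback` subclass that calls `buffer.reset()` in `_on_rollout_start` and `buffer.add(self.locals["new_obs"], self.locals["dones"], self.locals["infos"])` in `_on_step`. This assumes `collect_rollouts` exposes those local variable names to callbacks, which should be verified against the installed SB3 version.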
> However, these fields are required to be stored in the buffer to implement the original variant of the AIRL algorithm in imitation.

Do you have a code example of that?