[Feature Request] Store next observations and dones in RolloutBuffer

Open taufeeque9 opened this issue 2 years ago • 1 comments
🚀 Feature

Add next_observations and dones fields to the RolloutBuffer and the DictRolloutBuffer classes, similar to how it is done in the ReplayBuffer class.

Motivation

Currently, on-policy algorithms don't store the next observations and dones in their buffer during the collect_rollouts method, because none of the algorithms in stable-baselines3 require them. However, these fields are required to be stored in the buffer to implement the original variant of the AIRL algorithm in imitation.

Pitch

No response

Alternatives

No response

Additional context

No response

Checklist

  • [X] I have checked that there is no similar issue in the repo

taufeeque9 avatar Jan 11 '23 14:01 taufeeque9

Add next_observations and dones fields to the RolloutBuffer and the DictRolloutBuffer classes, similar to how it is done in the ReplayBuffer class.

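dones are already stored in episode_starts (shifted by one step), and next_observations can be retrieved using observations[i+1] (except at episode ends, where observations[i+1] is the first observation of the next episode rather than the terminal observation)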
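As a rough illustration of the indexing described above, here is a minimal NumPy sketch. It assumes arrays laid out like the rollout buffer's, with shape (n_steps, n_envs, ...). The function name and the last_obs / last_dones arguments are hypothetical: they stand for the observation and done flags produced by the final step of the rollout, which are not in the buffer itself.

```python
import numpy as np


def recover_dones_and_next_obs(observations, episode_starts, last_obs, last_dones):
    """Rebuild dones and next observations from rollout-style arrays.

    observations:   (n_steps, n_envs, *obs_shape)
    episode_starts: (n_steps, n_envs), truthy where a new episode begins
    last_obs:       (n_envs, *obs_shape), observation after the final step
    last_dones:     (n_envs,), done flags of the final step
    """
    observations = np.asarray(observations)
    episode_starts = np.asarray(episode_starts, dtype=bool)

    # dones[t] equals episode_starts[t + 1]; the final row comes from last_dones.
    dones = np.concatenate(
        [episode_starts[1:], np.asarray(last_dones, dtype=bool)[None]], axis=0
    )
    # next_obs[t] is observations[t + 1]; the final row comes from last_obs.
    next_obs = np.concatenate(
        [observations[1:], np.asarray(last_obs)[None]], axis=0
    )
    # Where dones is True, next_obs holds the first observation of the new
    # episode, not the terminal observation, so flag those entries.
    next_obs_is_valid = ~dones
    return dones, next_obs, next_obs_is_valid
```

The validity mask makes the "except for terminal obs" caveat explicit: entries where it is False need the terminal observation from another source.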

Alternatives

Why not implement a custom buffer for your use case? (You can fill it using a callback or a custom SB3 version.)
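One possible shape for such a custom buffer is sketched below. The class TransitionSideBuffer is hypothetical and not part of stable-baselines3. It relies on the vectorized-environment convention that, when an environment finishes an episode, the returned observation is already the reset observation and the true final observation is placed in infos[i]["terminal_observation"].

```python
import numpy as np


class TransitionSideBuffer:
    """Hypothetical side buffer for next observations and dones."""

    def __init__(self):
        self.next_observations = []
        self.dones = []

    def add(self, new_obs, dones, infos):
        # Copy so that later in-place changes by the caller do not leak in.
        next_obs = np.array(new_obs, copy=True)
        for env_idx, done in enumerate(dones):
            terminal_obs = infos[env_idx].get("terminal_observation")
            if done and terminal_obs is not None:
                # Replace the reset observation with the true terminal one.
                next_obs[env_idx] = terminal_obs
        self.next_observations.append(next_obs)
        self.dones.append(np.array(dones, dtype=bool))

    def as_arrays(self):
        # Shapes: (n_steps, n_envs, *obs_shape) and (n_steps, n_envs).
        return np.stack(self.next_observations), np.stack(self.dones)
```

From a callback's _on_step, add could be called with self.locals["new_obs"], self.locals["dones"], and self.locals["infos"]. Those key names are an assumption based on the local variable names used during rollout collection and should be checked against the installed SB3 version.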

However, these fields are required to be stored in the buffer to implement the original variant of the AIRL algorithm in imitation.

Do you have a code example of that?

araffin avatar Jan 12 '23 09:01 araffin