Store terminal observations in AutoResetWrapper infos for value bootstrap
Hello,
Distinguishing between truncation and termination of episodes is key to correct value estimation. Brax already exposes truncation and episode_done flags as part of infos, which makes this distinction possible.
When using the AutoResetWrapper (which is desirable) and an episode ends, the returned state (and therefore observation) is the first state of the new episode. However, value bootstrapping requires the value of the last state of the episode that just ended. For a terminated (not truncated) episode this poses no issue, since the value is 0 and the state can be ignored. For a truncated episode, however, one needs to predict the value of that last state, not of the first state of the new episode. This is not possible at the moment because that state is never exposed.
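To make the issue concrete, here is a minimal sketch of one-step TD targets that bootstrap through truncation. It is not Brax code: the function name td_targets and the final_values argument (the critic's value of the last observation before the reset) are hypothetical, and it assumes the done flag is 1 on both termination and truncation while the truncation flag is 1 only on truncation.

```python
import jax.numpy as jnp


def td_targets(rewards, dones, truncations, next_values, final_values, gamma=0.99):
    """One-step TD targets for a rollout of length T (all inputs have shape (T,)).

    next_values:  critic value of the observation returned by the env; after an
                  auto-reset this is the first observation of the new episode.
    final_values: critic value of the last observation before the reset
                  (hypothetical input, only meaningful where dones == 1).
    """
    # On a done step the returned observation belongs to the next episode,
    # so bootstrap from the value of the true last observation instead.
    bootstrap = jnp.where(dones > 0, final_values, next_values)
    # True termination: no future return. Truncation: keep bootstrapping.
    terminated = dones * (1.0 - truncations)
    return rewards + gamma * (1.0 - terminated) * bootstrap
```

Without access to the pre-reset observation, final_values cannot be computed, and a truncated step would bootstrap from the value of an unrelated initial state.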
I propose adding a simple obs_st field to the infos returned by the AutoResetWrapper, which exposes this observation to users for correct value bootstrapping.
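The idea can be sketched in isolation as follows. This is a toy illustration rather than the actual wrapper change: the helper auto_reset_obs and its arguments are hypothetical, and it assumes obs_st records the stepped observation on every step, before any reset is applied.

```python
import jax.numpy as jnp


def auto_reset_obs(obs, done, first_obs, info):
    """Toy auto-reset for the observation only (hypothetical helper).

    obs:       observation produced by the env step, shape (..., D).
    done:      episode-end flag, scalar or shape matching obs's leading dims.
    first_obs: first observation of the new episode, same shape as obs.
    info:      dict of extra outputs; a copy with "obs_st" added is returned.
    """
    # Record the stepped observation before it can be overwritten by the reset.
    info = {**info, "obs_st": obs}
    # Add trailing axes so the flag broadcasts against batched observations.
    mask = jnp.reshape(done, done.shape + (1,) * (obs.ndim - done.ndim))
    # Where the episode ended, return the first observation of the new episode.
    new_obs = jnp.where(mask > 0, first_obs, obs)
    return new_obs, info
```

A training loop could then evaluate the critic on info["obs_st"] for steps flagged as truncated and on the returned observation otherwise.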
I have added the corresponding test, and all tests pass.
I am new to open-source contributions and could not find a style guide or a linter for Brax. Please let me know if any modifications are needed.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
I have signed the CLA.