[BUG] Truncation not handled properly in n-step bootstrapped returns
Describe the bug
First off, thanks for the amazing codebase! I'm a big fan of the pure JAX approach.
I think most places in the code that use bootstrapping aren't handling episode truncation properly. For example, in ff_mz.py we have:
r_t = sequence.reward[:, :-1]
d_t = 1.0 - sequence.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
d_t = d_t[:, :-1]
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
r_t, d_t, search_values, config.system.n_steps
)
Note that sequence.done is set to timestep.last(), which in turn checks whether the step type is TRUNCATED or TERMINATED.
Let's consider what happens for a sequence of three timesteps (for a single env) corresponding to StepTypes (MID, TRUNCATED, FIRST). The arguments to batch_n_step_bootstrapped_returns might look like the following:
r_t = jnp.asarray([[1.0, 1.0, 0.0]])
discount_t = d_t = jnp.asarray([[0.99, 0.0, 0.99]])
v_t = search_values = jnp.asarray([[6.0, 5.0, 10.0]])
Let's say n_steps = 1. Then the value target for the "MID" timestep should actually be 1.0 + 0.99 * 5.0 = 5.95. But here, since we pass a discount of zero, the value estimate isn't captured, and the target ends up being just 1.0.
To Reproduce
It's pretty deep in the code, so I haven't gotten around to writing a unit test. But if you plug the above values into batch_n_step_bootstrapped_returns, I indeed get Array([[6.94, 1. , 9.9 ]], dtype=float32).
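For completeness, here's a standalone reproduction sketch. I'm rebuilding the batched helper by vmapping rlax's n_step_bootstrapped_returns over the batch dimension, which I'm assuming matches Stoix's batch_n_step_bootstrapped_returns:

import jax
import jax.numpy as jnp
import rlax

# Stand-in for Stoix's helper: vmap the single-sequence implementation
# over the leading batch axis (n is static, hence in_axes=None for it).
batch_n_step_bootstrapped_returns = jax.vmap(
    rlax.n_step_bootstrapped_returns, in_axes=(0, 0, 0, None)
)

r_t = jnp.asarray([[1.0, 1.0, 0.0]])
discount_t = jnp.asarray([[0.99, 0.0, 0.99]])
v_t = jnp.asarray([[6.0, 5.0, 10.0]])

# The zero discount at the truncated step wipes out the bootstrap term.
print(batch_n_step_bootstrapped_returns(r_t, discount_t, v_t, 1))
# Array([[6.94, 1.  , 9.9 ]], dtype=float32)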
Expected behavior
In the above test case, the second element of the returned array (the value target for the MID state) should be 5.95 as described above.
Context (Environment)
I'm on commit fe9de0a, running on macOS 15.6.1. I don't think the full context is relevant for this issue.
Additional context
None
Possible Solution
In the ExItTransition, we should include the original discount. Then we can pass d_t as the lambda_t parameter: a lambda of zero at an episode boundary cuts the n-step accumulation there, while the discount (which is zero only on true termination) still carries the bootstrap value through a truncation.
r_t = sequence.reward[:, :-1]
# Zero at every episode boundary (terminated or truncated); used as
# lambda_t to stop the n-step accumulation at the boundary.
d_t = 1.0 - sequence.done[:, :-1].astype(jnp.float32)
# Zero only on true termination, so the bootstrap value still flows
# through a truncation.
discount_t = sequence.discount[:, :-1] * config.system.gamma
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
    r_t, discount_t, search_values, config.system.n_steps, d_t
)
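As a quick sanity check on the toy example, here's a sketch using rlax.n_step_bootstrapped_returns directly (I'm assuming its lambda_t parameter behaves the way the fix needs; under the fix, the truncated step keeps its discount of gamma and lambda is zero at every episode boundary):

import jax.numpy as jnp
import rlax

r_t = jnp.asarray([1.0, 1.0, 0.0])
discount_t = jnp.asarray([0.99, 0.99, 0.99])  # zero only on termination
lambda_t = jnp.asarray([1.0, 0.0, 1.0])       # zero on any episode boundary
v_t = jnp.asarray([6.0, 5.0, 10.0])

# n_steps = 1: each target is r_t + discount_t * v_t, so the MID
# timestep now gets 1.0 + 0.99 * 5.0 = 5.95 instead of 1.0.
print(rlax.n_step_bootstrapped_returns(r_t, discount_t, v_t, 1, lambda_t))

# n_steps = 2: the zero lambda stops the trace at the boundary, so the
# truncated step keeps its one-step bootstrap (5.95) and the MID target
# chains into it: 1.0 + 0.99 * 5.95 = 6.8905.
print(rlax.n_step_bootstrapped_returns(r_t, discount_t, v_t, 2, lambda_t))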
Hey, yeah, I'm actually aware of this. See https://github.com/EdanToledo/Stoix/issues/142 for the GAE calculation. Basically, I fixed it for GAE, but the fix requires a specific way of inputting the data. Because of that, not every system handles truncation, though it's pretty easy to modify a system to handle it appropriately. I haven't made the fix for the bootstrapped returns yet, which is why it doesn't have a truncation flag. I would like to fix it, though, and make it similar to GAE so that we can handle both truncated and non-truncated scenarios.
If we use the discount as the lambda value, I'm not sure it will fix things for all scenarios, but it's worth checking out. I actually wanted to write test cases for all of these functions, but I haven't found the time.
Oh awesome, sorry I missed that thread! I see. I'll get back to you once I've tried the change I suggested above.
EDIT: I checked out the way PPO handles truncation and, ah, it's a bit more subtle than I realized. The solution I proposed only works if the critic's output on the truncated state is kept in the sequence, which I understand you've decided against in order to avoid "dummy" rewards/discounts on the initial timestep. If I understand correctly, in each _env_step PPO evaluates the critic on the current observation and also on the next observation (in timestep.extras["next_obs"]). This ends up evaluating the critic twice for most timesteps, with the same set of parameters. (I doubt the JAX compiler is smart enough to elide that.)
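To make that pattern concrete, here's a toy sketch (critic_apply and the observation values are made up; only the call pattern reflects what I'm describing):

import jax.numpy as jnp

# Illustrative stub standing in for the real critic network.
def critic_apply(params, obs):
    return jnp.vdot(params, obs)

params = jnp.ones(3)
obs_t = jnp.zeros(3)   # current observation
obs_tp1 = jnp.ones(3)  # would come from timestep.extras["next_obs"]

# Per _env_step: one critic call on the current observation for the
# value estimate, and one on next_obs for the truncation bootstrap.
value = critic_apply(params, obs_t)
next_value = critic_apply(params, obs_tp1)
# On a MID -> MID transition, next_value here is recomputed as `value`
# in the following _env_step, with the same parameters.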
Yeah, it's not the most optimal setup, unfortunately, but having dummy rewards and actions introduces a whole other can of worms and would require significant changes to all the other algorithms. Considering that simplicity for research's sake is preferred in this codebase, I believe that cost is worth it, but I didn't want to make that change to all the algorithms either; that's why only PPO pretty much handles truncation in Stoix currently. Let me know your thoughts.
For reference, I'm using the dm_env-style dummy reward/discount/action in another codebase, and as far as I can tell, setting reward=0 and discount=0 on initial timesteps takes care of bootstrapping; the only other modification is to mask out the policy loss on dummy actions (a sketch of the masking is below). I think, at least with next_obs_in_extras=False in stoa.core_wrappers.auto_reset.AutoResetWrapper, it should return the terminal timestep.
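Here's a minimal sketch of that masking (variable names and values are mine, not from either codebase):

import jax.numpy as jnp

# Per-timestep policy losses for a sequence containing a dummy FIRST
# step; is_first would come from timestep.first() in practice.
per_step_loss = jnp.asarray([0.5, 0.3, 0.7, 0.2])
is_first = jnp.asarray([1.0, 0.0, 0.0, 0.0])

# Zero out the loss on dummy actions and average over real steps only.
mask = 1.0 - is_first
policy_loss = jnp.sum(per_step_loss * mask) / jnp.sum(mask)
print(policy_loss)  # mean over the three real timesteps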