[BUG] GAE calculation wrong (deltas calculation)
Describe the bug
In the GAE calculation, the deltas are calculated before the main `_body` function.
From the PPO paper: $\delta_t = r_{t+1} + \gamma_{t+1} V(s_{t+1}) - V(s_t)$.
Using Stoix indexing: $\delta_t = r_t + \gamma_t V(s_t) - V(s_{t-1})$.
For this, `values[1:] - values[:-1]` is used, with `values[t]` corresponding to $V(s_{t-1})$.
There are now two cases: 1) episodes are truncated when too long, and 2) episodes finish (successfully) before `max_steps` or `rollout_length`.
- In truncated episodes, `truncation_mask[max_steps-1]` is 0. Thus, the last delta of the episode (`delta_t[max_steps-1]`) is nulled using the truncation mask. However, `delta_t[max_steps-1]` should have a non-zero value.
- In the case where the episodes are not truncated, i.e., they were successful, the `truncation_mask` is set to 1 everywhere. However, the `delta_t` calculation crosses the episode boundary.
This error in `delta_t` propagates into the advantage calculation as well. Since the advantages are wrong, the `target_values`, which are used to supervise the value function, are also wrong.
Delta calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L70
Advantage calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L74-L87
Target value calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L89
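For reference, a runnable toy sketch of the delta-and-mask step being described, reconstructed from the snippets quoted in this thread (shapes and dummy numbers are assumptions, not the linked code verbatim):

```python
import jax.numpy as jnp

# Dummy rollout of length T=4 purely for illustration; values has
# length T+1 (one value per state, plus the bootstrap state).
r_t             = jnp.zeros(4)
discount_t      = jnp.ones(4)
truncation_mask = jnp.array([1.0, 1.0, 1.0, 0.0])  # 0 at the truncation step
values          = jnp.arange(5, dtype=jnp.float32)

delta_t = r_t + discount_t * values[1:] - values[:-1]
# The line under discussion: zeroing deltas at truncation boundaries also
# zeroes the episode's last delta, which should instead be a bootstrapped value.
delta_t = delta_t * truncation_mask
print(delta_t)  # [1. 1. 1. 0.] -- last entry nulled by the mask
```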
To Reproduce
Steps to reproduce the behavior:
- Freshly clone the repo and install dependencies.
- Create a file `stoix/configs/env/navix/four_rooms.yaml` with the following content:

```yaml
# ---Environment Configs---
env_name: navix
scenario:
  name: Navix-FourRooms-v0
  task_name: navix-fourrooms-v0
kwargs: {
  max_steps: 37,
}
# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return
# Optional wrapper to flatten the observation space.
wrapper:
  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
```

- Set a jax breakpoint in `batch_truncated_generalized_advantage_estimation` in `stoix/utils/multistep.py`.
- Run `python stoix/systems/ppo/anakin/ff_ppo.py env=navix/four_rooms system.rollout_length=50 system.gamma=1 system.gae_lambda=1`.
- Since we are in the sparse-reward setting, `r_t` can only be 0, or 1 at most once per episode (check this by printing `r_t`):
```
(jdb) r_t.nonzero()
(Array([ 0,  1,  3,  9, 10, 11, 15, 16, 17, 18, 22, 23, 24, 28, 30, 30, 31,
        33, 36, 40, 41, 41, 45], dtype=int32), Array([416, 332, 772, 143, 591, 255,  60, 754, 399, 218, 720, 692, 939,
       198, 179, 898, 594, 998,  24, 725, 654, 789, 945], dtype=int32))
(jdb) r_t[:,0]
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
(jdb) r_t[:,60]
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
```
`target_values`, which in our case are the returns, i.e., the cumulative rewards (with `gamma=1` and `gae_lambda=1`, the GAE targets reduce to the undiscounted Monte-Carlo returns), should also be either 0 for unsuccessful episodes or 1 for successful episodes. Printing them, however, we see:
```
(jdb) target_values[:,0]
Array([-2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7107651, -2.7107651, -2.7107651,
       -2.7107651, -2.7107651, -2.7107651, -2.7107651, -2.7107651,
       -2.7107651, -2.7107651, -2.7107651, -2.7107651, -2.7107651], dtype=float32)
(jdb) target_values[:,60]
Array([ 1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       , -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507], dtype=float32)
```
(cc @emergenz)
In truncated episodes, the episode's last advantage is also (wrongly) nullified using `truncation_mask`.
Some more (potentially useful) context here: https://github.com/p-doom/reward-redistribution/pull/10
Yeah, so I agree, I think there might be a mistake somewhere here, especially regarding truncation.
Should we maybe add a suite of tests to ensure that we reach known results in common environments? I think the most obvious one is the 400 average return on breakout (https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/), but since minatar's breakout is different from the original one, we can't use the known breakout return. Robert has published gymnax results (https://github.com/RobertTLange/gymnax?tab=readme-ov-file#implemented-accelerated-environments-%EF%B8%8F), maybe we can start with those?
So I've actually been dying to do a set of benchmarks and standardised testing to ensure all future contributions don't hurt performance elsewhere by mistake. I have the compute to run such tests, I just need to set it all up. The problem with gymnax MinAtar is that it does not use truncation, I don't think. We need to choose an environment where both truncation is present and the lack of truncation would massively impact performance. A good example of this is the brax environments, but the problem there is that different versions of brax and different physics backends massively impact performance. If you are interested in helping me set up a standardised benchmarking/testing regime, that would be a super amazing contribution. I also really want to start adding unit tests for the more isolated functions such as the GAE estimation etc.
Okay, so I have made a PR to fix the GAE calc and added test cases copied and subsequently modified from RLax. Regarding the indexing, it's important to know that I am following RLax's indexing / Sutton-Barto indexing. The PPO paper uses different indexing, which may make it seem different.
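For reference, the two conventions write the same TD error with the index shifted on the reward and discount (my paraphrase of the two sources, not a quote from either):

$$\text{Sutton-Barto / RLax:}\quad \delta_t = R_{t+1} + \gamma_{t+1} V(S_{t+1}) - V(S_t)$$

$$\text{PPO paper:}\quad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$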
Thanks Edan! Will have a look asap
Hi,
Sorry for the late reply - took some time to review the PR. We realized that point 2 in the original issue is actually not the case:

> In the case where the episodes are not truncated, i.e., they were successful, the `truncation_mask` is set to 1 everywhere. However, the `delta_t` calculation crosses the episode boundary.

Nevertheless, point 1 was true. Previously, the last delta of the episode (`delta_t[max_steps-1]`) was nulled using the `truncation_mask`. However, `delta_t[max_steps-1]` should have a non-zero value. From our understanding, it should be the bootstrapped value, i.e., the estimated return for the rest of the trajectory. However, now, after the PR, it is the value of the first state of the next episode. So now we are, in fact, crossing episode boundaries.
Let us know whether this makes sense.
Hmm, I'm not sure I fully understand, but this jumps out to me as a potential autoreset API issue. If you turn off navix autoreset and use the jumanji autoreset API that I have, do you still see this issue?
Basically, our point is that the only logical change to the GAE calculation introduced by https://github.com/EdanToledo/Stoix/pull/148/ is the removal of the following line: https://github.com/EdanToledo/Stoix/blob/d07507419b5e0a98597612e758b52ea3165d91f9/stoix/utils/multistep.py#L71
This fixes issue 1 above. However, it introduces a new issue where, as Mihir described, the episode boundary is crossed during the `delta_t` calculation on truncated episodes.
This is not an issue for naturally terminated episodes (terminated by the environment, not because `max_steps` was crossed), because `discount_t` will be zero for those in `delta_t = r_t + discount_t * values[1:] - values[:-1]`. This is not the case for truncated episodes (see the toy sketch below).
Does that make sense to you?
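A toy numeric sketch of the boundary crossing (hypothetical values chosen for illustration, not Stoix code):

```python
import jax.numpy as jnp

# A 3-step episode truncated at t=2, followed by the first step of the next
# episode after autoreset. discount_t stays 1.0 at the truncation step
# because the episode did not terminate naturally.
r_t        = jnp.array([0.0, 0.0, 0.0, 0.0])
discount_t = jnp.array([1.0, 1.0, 1.0, 1.0])
values     = jnp.array([0.5, 0.6, 0.7, 9.9, 0.4])  # 9.9 = V(first obs of next episode)

delta_t = r_t + discount_t * values[1:] - values[:-1]
# delta_t[2] = 0 + 1.0 * 9.9 - 0.7 = 9.2: the delta at the truncation step
# bootstraps from the next episode's first value, crossing the boundary.
print(delta_t)  # [ 0.1  0.1  9.2 -9.5]
```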
Ah, I think I get it: so this is due to the 'value' being bootstrapped on actually being the value of the first observation of the new episode? If that's the case, I think I have a fix for it. Let me know if this is what you mean.
Yes, exactly!
I actually think I fixed this in a private version of Stoix for a paper and then it never got transferred...
So something came up; I actually think the problem will involve more changes than I'd hoped, but I'll try to get back to it later today. If you wanna give it a shot, the problem right now is essentially that we need to specify bootstrap values separately from baseline values.

Say we are bootstrapping: we never bootstrap from timestep t=0 (the initial obs of an episode), only ever from timestep t=1 onwards, so our bootstrap values can be a sequence looking as follows: [v_1, v_2, ..., v_{T-1}, v_T, v_1, v_2, v_3], where v_T is the value of the final observation. This can only be accessed via the `timestep.extras['next_obs']` variable, which is simply equal to `timestep.obs` if `done == False`. So essentially we need to get our bootstrap values from `critic(timestep.extras['next_obs'])`. For our baseline values, all we need is [v_0, v_1, v_2, ..., v_{T-1}, v_0, v_1, v_2, ..., v_{T-1}], so we never need the value of the final obs.

In summary, we would probably need to change the GAE function to take two arrays as input, bootstrap values and baseline values; we just need to think about whether this is okay for all cases. I hope this makes sense.
Just to further put this here so I don't forget: this is an example of constructing the baseline vs bootstrap values. Baseline values go from 0...k-1 and bootstrap values go from 1...k; the key difference is that baseline values skip over k if k == T and bootstrap values skip over 0 if k > T.
I just need to think about how we can correctly use this in the GAE calc.
```python
import jax
import jax.numpy as jnp

state, timestep = env.reset(jax.random.PRNGKey(3))
baseline_values = []
bootstrap_values = []
baseline_ts = []
bootstrap_ts = []
t = 0
for _ in range(20):
    # Baseline value: the observation the action is taken from (timestep t).
    baseline_values.append(timestep.observation.agent_view)
    baseline_ts.append(t)
    # Fixed key, so the same action is taken every step.
    action = jax.random.randint(jax.random.PRNGKey(0), shape=(1,), minval=0, maxval=5)
    state, timestep = env.step(state, action)
    t += 1
    # Bootstrap value: the next observation, taken from extras so that it is
    # the episode's true final observation even after autoreset.
    bootstrap_values.append(timestep.extras['next_obs'].agent_view)
    bootstrap_ts.append(t)
    if timestep.last():
        t = 0

print("")
print("Baseline values")
print(jnp.concat(baseline_values).flatten())
print(baseline_ts)
print("")
print("Bootstrap values")
print(jnp.concat(bootstrap_values).flatten())
print(bootstrap_ts)
```
```
Baseline values
[0. 1. 2. 3. 4. 5. 0. 1. 2. 3. 4. 5. 0. 1. 2. 3. 4. 5. 0. 1.]
[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1]

Bootstrap values
[1. 2. 3. 4. 5. 6. 1. 2. 3. 4. 5. 6. 1. 2. 3. 4. 5. 6. 1. 2.]
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2]
```
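A hedged sketch of what a GAE function taking the two arrays might look like (hypothetical name and signature, not the actual Stoix implementation; the `stop_t` continuation mask for halting the lambda-accumulation at episode boundaries is an assumption):

```python
import jax
import jax.numpy as jnp

def gae_with_bootstrap_values(r_t, discount_t, lambda_, baseline_values,
                              bootstrap_values, stop_t):
    """Hypothetical GAE over separate baseline and bootstrap value arrays.

    All arrays have length T. baseline_values[t] = V(s_t) for the state the
    action was taken from; bootstrap_values[t] = V of timestep.extras['next_obs'],
    so a truncated episode bootstraps from its own final observation rather
    than the next episode's first one. stop_t is 0.0 at any episode boundary
    (termination or truncation) and 1.0 elsewhere.
    """
    # Per-step TD errors; at a truncation step, discount_t is 1 and the
    # bootstrap value is the final observation's value, so the delta is correct.
    delta_t = r_t + discount_t * bootstrap_values - baseline_values

    def _body(acc, xs):
        delta, discount, stop = xs
        # Halt the lambda-accumulation at episode boundaries so advantages
        # never mix deltas from different episodes.
        acc = delta + stop * discount * lambda_ * acc
        return acc, acc

    # Scan backwards over time to accumulate the GAE recursion.
    _, advantage_t = jax.lax.scan(_body, 0.0, (delta_t, discount_t, stop_t),
                                  reverse=True)
    target_values = advantage_t + baseline_values
    return advantage_t, target_values
```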
I added the change here: https://github.com/EdanToledo/Stoix/tree/fix/gae_calc, but I haven't had time to check it thoroughly.
Doing a test between the main branch and that branch, I got these results on 10 seeds per env for ant, halfcheetah, humanoid, and hopper.
Using the RLiable eval, this can be seen as approximately 40 seeds.
Given the relatively overlapping confidence intervals, I'd say we can't claim it 100% makes a difference, but given these results I'd say it at the very least performs the same and likely improves things.
Is it possible to architecturally compare this against baseline results for, e.g., ant and halfcheetah from libraries like Stable Baselines 3?
Unfortunately, brax is quite different to the mujoco envs that everyone used to use. Additionally, even within brax, different versions and different physics backends change the results significantly. All this said, I am now trying to do these comparisons and set everything up to reproduce baseline results very explicitly. See the https://github.com/EdanToledo/Stoix/tree/feat/testing_framework branch. Once we have this GAE thing fixed, that branch will be the main focus.
Nice! Thanks for all the timely work! I think for a quick sanity check, we can try to reproduce these results from PureJaxRL. (Though for that we would have to fix the GymnaxWrapper bug first.)
I don't think cartpole or breakout have truncation, and regarding those two results, I have reproduced them in my own capacity.
Also see the results here: https://arxiv.org/pdf/2411.00666. If you go all the way to the appendix, there are results for each task, and these results were generated using Stoix's PPO, although these baselines were immensely hyperparameter-tuned per task.
So, to finally close this issue: I'm pretty sure the current GAE calculation is correct. We have provided both the old API and the new API, so if a user has to support truncation, they just need to ensure their state values and bootstrap values are correctly aligned. If they don't care about truncation, or the env does not have it, they can use the old 'values' variable. Not every implementation supports truncation by default, but this is fine, as ultimately a researcher using this should be aware of it. We can add a note in the README at some point, but for now I'm satisfied. Thanks guys for spotting this, and let me know if you think there are still any issues.
Thanks for all the work!