
[BUG] GAE calculation wrong (deltas calculation)

Open maharajamihir opened this issue 9 months ago • 21 comments

Describe the bug

In the GAE calculation, the deltas are computed before the main `_body` function. From the PPO paper: $\delta_t = r_{t+1} + \gamma_{t+1} V(s_{t+1}) - V(s_t)$. In Stoix's indexing this becomes $\delta_t = r_t + \gamma_t V(s_t) - V(s_{t-1})$, implemented via `values[1:] - values[:-1]`, with `values[t]` corresponding to $V(s_{t-1})$. There are two cases: 1) episodes that are truncated because they run too long, and 2) episodes that finish (successfully) before `max_steps` or `rollout_length`.

  1. In truncated episodes, `truncation_mask[max_steps-1]` is 0, so the last delta of the episode (`delta_t[max_steps-1]`) is nulled by the truncation mask. However, `delta_t[max_steps-1]` should have a non-zero value.
  2. In episodes that are not truncated, i.e., they finished successfully, the `truncation_mask` is 1 everywhere, yet the `delta_t` calculation still crosses the episode boundary.

This error in `delta_t` propagates into the advantage calculation. Since the advantages are wrong, the `target_values`, which supervise the value function, are wrong as well.

Delta calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L70
Advantage calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L74-L87
Target value calculation: https://github.com/EdanToledo/Stoix/blob/11dca7de421a9859e0cc213dd00e6e7ed7e6a205/stoix/utils/multistep.py#L89
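To make case 1 concrete, here is a minimal sketch of the masked delta computation described above, using plain numpy in place of jax.numpy; the `deltas` helper and the toy numbers are hypothetical, not the actual Stoix code:

```python
import numpy as np

# Hypothetical helper mirroring the masked delta computation described above
# (plain numpy in place of jax.numpy; not the actual Stoix code).
def deltas(r_t, discount_t, values, truncation_mask):
    # delta_t = r_t + discount_t * V(s_{t+1}) - V(s_t), then nulled at truncated steps
    d = r_t + discount_t * values[1:] - values[:-1]
    return d * truncation_mask

# Toy 3-step episode truncated at the last step: the sparse reward of 1 arrives
# there, but the mask zeroes that step's delta, which is the behaviour in case 1.
r_t = np.array([0.0, 0.0, 1.0])
discount_t = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5, 0.9])      # V(s_0..s_3); s_3 starts a new episode
truncation_mask = np.array([1.0, 1.0, 0.0])  # 0 at the truncated step

print(deltas(r_t, discount_t, values, truncation_mask))  # -> [0. 0. 0.]
```

The unmasked last delta would be 1 + 0.9 - 0.5 = 1.4, so the mask discards real reward signal.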

To Reproduce

Steps to reproduce the behavior:

  1. Freshly clone the repo and install dependencies
  2. create a file stoix/configs/env/navix/four_rooms.yaml with the following content:
# ---Environment Configs---
env_name: navix
scenario:
  name: Navix-FourRooms-v0
  task_name: navix-fourrooms-v0

kwargs: {
  max_steps: 37,
}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
  3. Set a jax breakpoint in `batch_truncated_generalized_advantage_estimation` in `stoix/utils/multistep.py`.
  4. Run `python stoix/systems/ppo/anakin/ff_ppo.py env=navix/four_rooms system.rollout_length=50 system.gamma=1 system.gae_lambda=1`.
  5. Since we are in the sparse-reward setting, `r_t` is 0 everywhere, with at most a single 1 per episode (check this by printing `r_t`):
(jdb) r_t.nonzero()
(Array([ 0,  1,  3,  9, 10, 11, 15, 16, 17, 18, 22, 23, 24, 28, 30, 30, 31,
       33, 36, 40, 41, 41, 45], dtype=int32), Array([416, 332, 772, 143, 591, 255,  60, 754, 399, 218, 720, 692, 939,
       198, 179, 898, 594, 998,  24, 725, 654, 789, 945], dtype=int32))
(jdb) r_t[:,0]
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
(jdb) r_t[:,60]
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)
  6. `target_values`, which in our case are the returns, i.e., cumulative rewards, should therefore also be either 0 for unsuccessful episodes or 1 for successful episodes. Printing them instead shows:
(jdb) target_values[:,0]
Array([-2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7359262, -2.7359262, -2.7359262,
       -2.7359262, -2.7359262, -2.7107651, -2.7107651, -2.7107651,
       -2.7107651, -2.7107651, -2.7107651, -2.7107651, -2.7107651,
       -2.7107651, -2.7107651, -2.7107651, -2.7107651, -2.7107651],      dtype=float32)
(jdb) target_values[:,60]
Array([ 1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       ,  1.       ,  1.       ,  1.       ,  1.       ,
        1.       , -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507,
       -2.6827507, -2.6827507, -2.6827507, -2.6827507, -2.6827507],      dtype=float32)
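As a sanity check of the expectation above: with gamma = lambda = 1, the GAE recursion telescopes so that advantages + values reduce to the undiscounted return-to-go. A minimal numpy sketch (hypothetical names and numbers, not the Stoix implementation):

```python
import numpy as np

def gae_targets(r_t, discount_t, lam, values):
    # Plain backward GAE recursion, then targets = advantages + baseline values.
    delta = r_t + discount_t * values[1:] - values[:-1]
    adv = np.zeros_like(r_t)
    acc = 0.0
    for t in reversed(range(len(r_t))):
        acc = delta[t] + discount_t[t] * lam * acc
        adv[t] = acc
    return adv + values[:-1]

# gamma = lambda = 1, a single sparse reward at the final (terminal) step,
# arbitrary critic outputs: the targets collapse to the episode return of 1.
r_t = np.array([0.0, 0.0, 1.0])
discount_t = np.array([1.0, 1.0, 0.0])    # 0 at natural termination
values = np.array([0.25, -0.5, 0.75, 0.5])
print(gae_targets(r_t, discount_t, 1.0, values))  # -> [1. 1. 1.]
```

So values in target_values that are neither 0 nor 1 in this setting indicate the bug described above.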

(cc @emergenz)

maharajamihir avatar Mar 13 '25 14:03 maharajamihir

In truncated episodes, the episode's last advantage is also (wrongly) nullified by the `truncation_mask`.

Some more (potentially useful) context here: https://github.com/p-doom/reward-redistribution/pull/10

emergenz avatar Mar 13 '25 15:03 emergenz

Yeah, I agree, I think there might be a mistake somewhere here, especially regarding truncation.

EdanToledo avatar Mar 21 '25 18:03 EdanToledo

Should we maybe add a suite of tests to ensure that we reach known results in common environments? I think the most obvious one is the 400 average return on Breakout (https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/), but since MinAtar's Breakout is different from the original one, we can't use that known return. Robert has published gymnax results (https://github.com/RobertTLange/gymnax?tab=readme-ov-file#implemented-accelerated-environments-%EF%B8%8F); maybe we can start with those?

emergenz avatar Mar 21 '25 19:03 emergenz

So I've actually been dying to do a set of benchmarks and standardised tests to ensure future contributions don't hurt performance elsewhere by mistake. I have the compute to run such tests; I just need to set it all up. The problem with gymnax MinAtar is that I don't think it uses truncation. We need an environment where truncation is present and where ignoring it would massively impact performance. A good example is Brax, but the problem there is that different Brax versions and physics backends massively impact performance. If you are interested in helping me set up a standardised benchmarking/testing regime, that would be a super amazing contribution. I also really want to start adding unit tests for the more isolated functions such as the GAE estimation.

EdanToledo avatar Mar 21 '25 19:03 EdanToledo

Okay, so I have made a PR to fix the GAE calculation and added test cases copied (and subsequently modified) from RLax. Regarding the indexing, it is important to know that I am following RLax's / Sutton-Barto indexing. The PPO paper uses different indexing, which may make it seem different.

EdanToledo avatar Mar 22 '25 14:03 EdanToledo

Thanks Edan! Will have a look asap

emergenz avatar Mar 22 '25 15:03 emergenz

Hi,

Sorry for the late reply; it took some time to review the PR. We realized that point 2 of the original issue is actually not the case:

  1. In the case where the episodes are not truncated, i.e., they were successful, the truncation_mask is set to 1 everywhere. However, the delta_t calculation crosses the episode boundary.

Nevertheless, point 1 was true. Previously, the last delta of the episode (`delta_t[max_steps-1]`) was nulled by the `truncation_mask`, even though it should have a non-zero value. From our understanding, it should use the bootstrapped value, i.e., the estimated return for the rest of the trajectory. However, after the PR, it is now the value of the first state of the next episode. So now we are, in fact, crossing episode boundaries.

Let us know whether this makes sense.

maharajamihir avatar Mar 24 '25 18:03 maharajamihir

Hmm, I'm not sure I fully understand, but this jumps out at me as a potential autoreset API issue. If you turn off Navix autoreset and use the Jumanji autoreset API that I have, do you still see this issue?

EdanToledo avatar Mar 24 '25 18:03 EdanToledo

Basically, our point is that the only logical change to the GAE calculation introduced by https://github.com/EdanToledo/Stoix/pull/148/ is the removal of the following line: https://github.com/EdanToledo/Stoix/blob/d07507419b5e0a98597612e758b52ea3165d91f9/stoix/utils/multistep.py#L71

This fixes point 1 of this issue. However, it introduces a new problem where, as Mihir described, the episode boundary is crossed during the `delta_t` calculation on truncated episodes.

This is not an issue for naturally terminated episodes (terminated by the environment, not because `max_steps` was reached), because `discount_t` is zero for those in `delta_t = r_t + discount_t * values[1:] - values[:-1]`. That is not the case for truncated episodes.

Does that make sense to you?
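To make the boundary crossing concrete, here is a toy numpy sketch (hypothetical numbers; under auto-reset, the value after a truncated step belongs to the next episode):

```python
import numpy as np

r_t = np.array([0.0, 0.0, 1.0])
values = np.array([0.5, 0.5, 0.5, 0.9])  # values[3] is V of the NEXT episode's first obs

# Truncated episode: truncation does not zero the discount, so the last delta
# bootstraps from values[3] and crosses the episode boundary.
discount_trunc = np.array([1.0, 1.0, 1.0])
delta_trunc = r_t + discount_trunc * values[1:] - values[:-1]
print(delta_trunc[-1])  # includes the 0.9 belonging to the next episode

# Natural termination: discount is 0 at the terminal step, so values[3] is unused.
discount_term = np.array([1.0, 1.0, 0.0])
delta_term = r_t + discount_term * values[1:] - values[:-1]
print(delta_term[-1])   # -> 0.5
```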

emergenz avatar Mar 25 '25 08:03 emergenz

Ah, I think I get it: so this is because the 'value' being bootstrapped on is actually the value of the first observation of the new episode? If that's the case, I think I have a fix for it. Let me know if this is what you mean.

EdanToledo avatar Mar 25 '25 10:03 EdanToledo

Yes, exactly!

emergenz avatar Mar 25 '25 10:03 emergenz

I actually think I fixed this in a private version of Stoix for a paper, and then it never got transferred...

EdanToledo avatar Mar 25 '25 10:03 EdanToledo

So something came up; I actually think the problem will involve more changes than I'd hoped, but I'll try to get back to it later today. If you want to give it a shot: right now the problem is essentially that we need to specify bootstrap values separately from baseline values. Say we are bootstrapping. We never bootstrap from timestep t=0 (the initial obs of an episode), only ever from timestep t=1, so our bootstrap values can be a sequence of values looking as follows: [v_1, v_2, ..., v_{T-1}, v_T, v_1, v_2, v_3], where v_T is the value of the final observation. This can only be accessed via the timestep.extras['next_obs'] variable, which is simply equal to timestep.obs when done==False. So essentially we need to get our bootstrap values from critic(timestep.extras['next_obs']). For our baseline values all we need is [v_0, v_1, v_2, ..., v_{T-1}, v_0, v_1, v_2, ..., v_{T-1}], so we never need the value of the final obs. In summary, we would probably need to change the GAE function to take two arrays as input, bootstrap values and baseline values; we just need to think about whether this is okay for all cases. I hope this makes sense.
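For illustration, a hedged sketch of what a GAE function taking the two arrays could look like (the `truncated_gae` name, signature, and recursion are written from the description above, not the actual Stoix code, and handling of the truncation mask inside the recursion is omitted):

```python
import numpy as np

def truncated_gae(r_t, discount_t, lam, baseline_values, bootstrap_values):
    # delta_t = r_t + discount_t * bootstrap_values[t] - baseline_values[t]
    # A_t     = delta_t + discount_t * lam * A_{t+1}
    delta_t = r_t + discount_t * bootstrap_values - baseline_values
    adv = np.zeros_like(r_t)
    acc = 0.0
    for t in reversed(range(len(r_t))):
        acc = delta_t[t] + discount_t[t] * lam * acc
        adv[t] = acc
    return adv

# With a zero-initialised critic, gamma = lambda = 1, and a single terminal
# reward, every step's advantage is the episode return:
r_t = np.array([0.0, 0.0, 1.0])
discount_t = np.array([1.0, 1.0, 0.0])
print(truncated_gae(r_t, discount_t, 1.0, np.zeros(3), np.zeros(3)))  # -> [1. 1. 1.]
```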

EdanToledo avatar Mar 25 '25 14:03 EdanToledo

Just to further put this here so I don't forget: this would be an example of constructing the baseline vs. bootstrap values. Baseline values go from 0...k-1 and bootstrap values go from 1...k; the key difference is that baseline values skip over k if k == T and bootstrap values skip over 0 if k > T.

I just need to think about how we can correctly use this in the GAE calc.


import jax
import jax.numpy as jnp

# NOTE: `env` is the toy environment used for this illustration (its agent_view
# simply counts timesteps); the imports above were not in the original snippet.
state, timestep = env.reset(jax.random.PRNGKey(3))

baseline_values = []
bootstrap_values = []

baseline_ts = []
bootstrap_ts = []

t = 0
for _ in range(20):
    baseline_values.append(timestep.observation.agent_view)
    baseline_ts.append(t)

    action = jax.random.randint(jax.random.PRNGKey(0), shape=(1,), minval=0, maxval=5)
    state, timestep = env.step(state, action)
    t += 1

    bootstrap_values.append(timestep.extras['next_obs'].agent_view)
    bootstrap_ts.append(t)

    if timestep.last():
        t = 0

print("")
print("Baseline values")
print(jnp.concatenate(baseline_values).flatten())
print(baseline_ts)
print("")
print("Bootstrap values")
print(jnp.concatenate(bootstrap_values).flatten())
print(bootstrap_ts)

Output:

Baseline values
[0. 1. 2. 3. 4. 5. 0. 1. 2. 3. 4. 5. 0. 1. 2. 3. 4. 5. 0. 1.]
[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1]

Bootstrap values
[1. 2. 3. 4. 5. 6. 1. 2. 3. 4. 5. 6. 1. 2. 3. 4. 5. 6. 1. 2.]
[1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2]

EdanToledo avatar Mar 25 '25 17:03 EdanToledo

I added the change on the https://github.com/EdanToledo/Stoix/tree/fix/gae_calc branch, but I haven't had time to check it thoroughly yet.

EdanToledo avatar Mar 26 '25 12:03 EdanToledo

Running a test between the main branch and that branch on ant, halfcheetah, humanoid, and hopper, with 10 seeds per environment, gave these results.

Using the RLiable evaluation, so this can be seen as approximately 40 seeds in aggregate.

[Four result plots attached, one per environment.]

Given that the confidence intervals overlap considerably, I'd say we can't claim with 100% certainty that it makes a difference, but based on these results it at least improves performance or leaves it the same.

EdanToledo avatar Mar 26 '25 16:03 EdanToledo

Is it possible to compare this, architecture for architecture, against baseline results on e.g. ant and halfcheetah from libraries like Stable Baselines 3?

emergenz avatar Mar 26 '25 16:03 emergenz

Unfortunately, Brax is quite different from the MuJoCo envs that everyone used to use. Additionally, even different Brax versions and physics backends change the results significantly. All that said, I am now trying to do these comparisons and to set everything up to reproduce baseline results very explicitly; see the https://github.com/EdanToledo/Stoix/tree/feat/testing_framework branch. Once this GAE issue is fixed, that branch will be the main focus.

EdanToledo avatar Mar 26 '25 17:03 EdanToledo

Nice! Thanks for all the timely work! I think for a quick sanity check, we can try to reproduce these results from PureJaxRL (though for that we would have to fix the GymnaxWrapper bug first):

[Image: PureJaxRL benchmark results.]

emergenz avatar Mar 26 '25 17:03 emergenz

I don't think CartPole or Breakout have truncation, and as for those two results, I have reproduced them in my own capacity.

EdanToledo avatar Mar 26 '25 17:03 EdanToledo

Also see the results here: https://arxiv.org/pdf/2411.00666. If you go all the way to the appendix, there are results for each task, and these were generated using Stoix's PPO, although the baselines were heavily hyperparameter-tuned per task.

EdanToledo avatar Mar 26 '25 17:03 EdanToledo

So, to finally close this issue: I'm pretty sure the current GAE calculation is correct. We have provided both the old API and the new API, so if a user needs to support truncation, they just have to ensure their state values and bootstrap values are correctly aligned. If they don't care about truncation, or the env does not have it, they can use the old 'values' variable. Not every implementation supports truncation by default, but this is fine, as a researcher using this should be aware of it. We can add a note in the readme at some point, but for now I'm satisfied. Thanks for spotting this, and let me know if you think there are still any issues.

EdanToledo avatar Sep 16 '25 14:09 EdanToledo

Thanks for all the work!

emergenz avatar Sep 16 '25 16:09 emergenz