Potential bugs in RecordEpisodeStatistics
Problem Description
Checklist
- [x] I have installed dependencies via poetry install (see CleanRL's installation guideline).
- [x] I have checked that there is no similar issue in the repo.
- [x] I have checked the documentation site and found no relevant information in GitHub issues.
Current Behavior
In ppo_rnd_envpool.py (and also ppo_atari_envpool.py), the implementation of RecordEpisodeStatistics keeps accumulating rewards after a time-limit truncation, since self.episode_returns is only masked by info["terminated"]. This means that in Atari, the returns of two independent rounds (i.e., one round ends when the agent loses all of its lives) will be accumulated if the previous round was reset due to time-limit truncation.
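To see why masking only by info["terminated"] leaks returns across the truncation boundary, here is a minimal sketch with made-up numbers (two hypothetical envs; env 0 hits the time limit without terminating):

import numpy as np

episode_returns = np.array([10.0, 3.0], dtype=np.float32)  # running returns of two envs
terminated = np.array([0, 0], dtype=np.int32)               # neither env terminated
truncated = np.array([True, False])                         # env 0 truncated by the time limit

# Current masking: only "terminated" is used, so env 0 keeps its 10.0 and the
# next round's rewards are added on top of it.
print(episode_returns * (1 - terminated))                 # env 0 keeps its 10.0

# Masking with truncation included resets env 0 as expected.
print(episode_returns * (1 - (terminated | truncated)))   # env 0 is reset to 0.0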
The following is what I observe when training using envpool with max_episode_steps=27000 (the default value in envpool).
Here is how I log (adapted from this line):
for idx, d in enumerate(done):
    log_rewards[idx].append(reward[idx])
    if info["terminated"][idx]:
        avg_returns.append(info["r"][idx])
        print(f"Env {idx} finishes a round with length {info['l'][idx]} and score {info['r'][idx]}")
        log_rewards[idx] = []
Here are the logs I got:
Env 0 finishes a round with length 54012 and score 1900
...
Env 0 finishes a round with length 81016 and score 4900
It's problematic since info["l"][idx] should not exceed 27000. I checked that when the timestep hits 27000, the environment will be reset. This means the scores across two rounds are summed up.
Expected Behavior
The game score is expected to be the sum of rewards over all of the agent's lives within one round.
Possible Solution
Should we change this line to:
self.episode_returns *= 1 - (infos["terminated"] | infos["TimeLimit.truncated"])
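If that is the right direction, a sketch of the patched masking inside the wrapper's step() could look like the snippet below (assuming envpool exposes both infos["terminated"] and infos["TimeLimit.truncated"], as in the reproduction script below). The episode length counter presumably needs the same mask, since the recorded length also keeps growing across the time-limit reset (Episode 1's length is reported as 60 instead of 31 in the output below):

# Hypothetical patch inside RecordEpisodeStatistics.step(), not the current implementation:
done_or_truncated = infos["terminated"] | infos["TimeLimit.truncated"]
self.episode_returns *= 1 - done_or_truncated
self.episode_lengths *= 1 - done_or_truncated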
Steps to Reproduce
Run the following script:
import gym
import numpy as np
import envpool
is_legacy_gym = True
# From: https://github.com/sail-sg/envpool/blob/main/examples/cleanrl_examples/ppo_atari_envpool.py
class RecordEpisodeStatistics(gym.Wrapper):
    def __init__(self, env, deque_size=100):
        super(RecordEpisodeStatistics, self).__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_returns = None
        self.episode_lengths = None
        # get if the env has lives
        self.has_lives = False
        env.reset()
        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
        if info["lives"].sum() > 0:
            self.has_lives = True
            print("env has lives")

    def reset(self, **kwargs):
        if is_legacy_gym:
            observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
        else:
            observations, _ = super(RecordEpisodeStatistics, self).reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.lives = np.zeros(self.num_envs, dtype=np.int32)
        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations

    def step(self, action):
        if is_legacy_gym:
            observations, rewards, dones, infos = super(
                RecordEpisodeStatistics, self
            ).step(action)
        else:
            observations, rewards, term, trunc, infos = super(
                RecordEpisodeStatistics, self
            ).step(action)
            dones = term + trunc
        self.episode_returns += infos["reward"]
        self.episode_lengths += 1
        self.returned_episode_returns[:] = self.episode_returns
        self.returned_episode_lengths[:] = self.episode_lengths
        all_lives_exhausted = infos["lives"] == 0
        if self.has_lives:
            self.episode_returns *= 1 - all_lives_exhausted
            self.episode_lengths *= 1 - all_lives_exhausted
        else:
            self.episode_returns *= 1 - dones
            self.episode_lengths *= 1 - dones
        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
        return (
            observations,
            rewards,
            dones,
            infos,
        )
# From: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool.py
# class RecordEpisodeStatistics(gym.Wrapper):
#     def __init__(self, env, deque_size=100):
#         super().__init__(env)
#         self.num_envs = getattr(env, "num_envs", 1)
#         self.episode_returns = None
#         self.episode_lengths = None
#     def reset(self, **kwargs):
#         observations = super().reset(**kwargs)
#         self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         self.lives = np.zeros(self.num_envs, dtype=np.int32)
#         self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
#         self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
#         return observations
#     def step(self, action):
#         observations, rewards, dones, infos = super().step(action)
#         self.episode_returns += infos["reward"]
#         self.episode_lengths += 1
#         self.returned_episode_returns[:] = self.episode_returns
#         self.returned_episode_lengths[:] = self.episode_lengths
#         self.episode_returns *= 1 - infos["terminated"]
#         self.episode_lengths *= 1 - infos["terminated"]
#         infos["r"] = self.returned_episode_returns
#         infos["l"] = self.returned_episode_lengths
#         return (
#             observations,
#             rewards,
#             dones,
#             infos,
#         )
if __name__ == "__main__":
np.random.seed(1)
envs = envpool.make(
"UpNDown-v5",
env_type="gym",
num_envs=1,
episodic_life=True, # Espeholt et al., 2018, Tab. G.1
repeat_action_probability=0, # Hessel et al., 2022 (Muesli) Tab. 10
full_action_space=False, # Espeholt et al., 2018, Appendix G., "Following related work, experts use game-specific action sets."
max_episode_steps=30, # Set as 50 to hit timelimit faster
reward_clip=True,
seed=1,
)
envs = RecordEpisodeStatistics(envs)
num_episodes = 2
episode_count = 0
cur_episode_len = 0
cur_episode_return = 0
my_episode_returns = []
my_episode_lens = []
# Track episode returns here to compare with the ones recorded with `RecordEpisodeStatistics`
recorded_episode_returns = []
recorded_episode_lens = []
obs = envs.reset()
while episode_count < num_episodes:
action = np.random.randint(0, envs.action_space.n, 1)
obs, reward, done, info = envs.step(action)
cur_episode_return += info["reward"][0]
cur_episode_len += 1
print(f"Ep={episode_count}, EpStep={cur_episode_len}, Return={info['r']}, MyReturn={cur_episode_return}, Terminated={info['terminated']}, Timeout={info['TimeLimit.truncated']}, Lives={info['lives']}")
# info["terminated"] = True: Game over.
# info["TimeLimit.truncated"] = True: Timeout, the environment will be reset (so the episode return should be reset too)
if info["terminated"][0] or info["TimeLimit.truncated"][0]:
recorded_episode_returns.append(info["r"][0]) # Append the episode return recorded in `RecordEpisodeStatistics`
recorded_episode_lens.append(info["l"][0]) # Append the episode length recorded in `RecordEpisodeStatistics`
my_episode_returns.append(cur_episode_return)
my_episode_lens.append(cur_episode_len)
print(f"Episode {episode_count}'s length is {cur_episode_len} (terminated={info['terminated']}, timeout={info['TimeLimit.truncated']})")
episode_count += 1
cur_episode_return *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])
cur_episode_len *= 1 - (info["terminated"][0] | info["TimeLimit.truncated"][0])
for episode_idx in range(num_episodes):
print(f"Episode {episode_idx}'s return is supposed to be {my_episode_returns[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_returns[episode_idx]}")
print(f"Episode {episode_idx}'s len is supposed to be {my_episode_lens[episode_idx]}, but the wrapper `RecordEpisodeStatistics` gives {recorded_episode_lens[episode_idx]}")
You should see the output:
env has lives
Ep=0, EpStep=1, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=2, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=3, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=4, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=5, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=6, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=7, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=8, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=9, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=10, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=11, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=12, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=13, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=14, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=15, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=16, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=17, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=18, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=19, Return=[0.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=20, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=21, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=22, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=23, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=24, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=25, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=26, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=27, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=28, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=2, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=3, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=4, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=5, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=6, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=7, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=8, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=9, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=10, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=11, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=12, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=13, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=14, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=15, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=16, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=17, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=18, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=19, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=20, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=21, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=22, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=23, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=24, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=25, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=26, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=27, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=28, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=29, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=30, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[False], Lives=[5]
Ep=1, EpStep=31, Return=[20.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 1's length is 31 (terminated=[0], timeout=[ True])
Episode 0's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 10.0
Episode 0's len is supposed to be 29, but the wrapper `RecordEpisodeStatistics` gives 29
Episode 1's return is supposed to be 10.0, but the wrapper `RecordEpisodeStatistics` gives 20.0
Episode 1's len is supposed to be 31, but the wrapper `RecordEpisodeStatistics` gives 60
See the above example's output:
Ep=0, EpStep=29, Return=[10.], MyReturn=10.0, Terminated=[0], Timeout=[ True], Lives=[5]
Episode 0's length is 29 (terminated=[0], timeout=[ True])
Ep=1, EpStep=1, Return=[10.], MyReturn=0.0, Terminated=[0], Timeout=[False], Lives=[5]
The return in the new episode (Ep=1) is not reset to zero but is carried over from the previous episode. The expected behavior is to reset the return counter to zero upon timeout.
@vwxyzjn