Truncation not handled correctly when `optimize_memory_usage=True`
Problem Description
In `dqn_atari.py`, for example, the replay buffer is instantiated with the `optimize_memory_usage=True` flag. This makes the buffer keep only a single stored array for observations and take `next_obs = observations[i + 1]` when sampling. However, cleanrl does its own logic to handle truncation (`if trunc: real_next_obs[idx] = infos["final_observation"][idx]`), and with `optimize_memory_usage` that patch is not reflected in the stored/sampled data: the patched next observation written at slot `i + 1` is overwritten on the next `add`, which writes the auto-reset environment's first observation into that same slot.
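To illustrate the mechanism, here is a minimal, self-contained sketch (hypothetical code, not SB3's actual implementation; `obs_buf`, `add`, and the example values are made up) of how a single shared observations array loses the patched final observation:

```python
import numpy as np

# A single shared array stands in for the memory-optimized buffer's observations
# storage, which serves as both obs and next_obs.
obs_buf = np.zeros(8)
pos = 0


def add(obs, next_obs):
    """Write obs at pos and next_obs at pos + 1, mirroring the memory-optimized buffer."""
    global pos
    obs_buf[pos] = obs
    obs_buf[(pos + 1) % len(obs_buf)] = next_obs
    pos = (pos + 1) % len(obs_buf)


# Step t: the episode is truncated, so cleanrl patches in the final observation (7.0).
add(obs=3.0, next_obs=7.0)
# Step t + 1: the vector env has auto-reset, so the new episode's first observation (0.0)
# is passed as obs, and it lands in the very slot that held the patched 7.0.
add(obs=0.0, next_obs=1.0)

# Sampling index t reads next_obs from slot t + 1, which now holds the reset
# observation instead of the final observation of the truncated episode.
print(obs_buf[0], obs_buf[1])  # 3.0 0.0 (the 7.0 is gone)
```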
Checklist
- [Yes] I have installed dependencies via `poetry install` (see CleanRL's installation guideline).
- [Yes] I have checked that there is no similar issue in the repo.
- [Yes] I have checked the documentation site and found no relevant information in GitHub issues.
Current Behavior
Instead of `data.next_observations[i]` being the correct next observation, when an episode is truncated the sampled next observation is the first observation of the reset environment.
Expected Behavior
When an episode is truncated, `data.next_observations[i]` should be that episode's final observation (the one patched in from `infos["final_observation"]`), not the first observation of the reset environment.
Possible Solution
I'm guessing there's a way to make this work properly, but for now the easiest thing to do is set `optimize_memory_usage` to `False`.
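For reference, a sketch of that workaround (using the buffer size and env from the repro below, not `dqn_atari.py`'s actual arguments): the only change is flipping `optimize_memory_usage` to `False`, which gives next observations their own storage so the `final_observation` patch survives.

```python
import gymnasium as gym
from stable_baselines3.common.buffers import ReplayBuffer

envs = gym.vector.SyncVectorEnv([lambda: gym.make("MountainCar-v0")])

rb = ReplayBuffer(
    1000,                          # buffer size, as in the repro below
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=False,   # the workaround: uses more memory, stores next_obs separately
    handle_timeout_termination=False,
)
```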
Steps to Reproduce
Here's a minimal code example, where the important parts are directly cribbed from `dqn_atari.py`. Switching to `optimize_memory_usage=False` prevents the assertion error.
```python
import gymnasium as gym
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv(
    [make_env("MountainCar-v0", i, i, False, "testing") for i in [0]]
)
obs, _ = envs.reset(seed=0)

rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=True,
    # optimize_memory_usage=False,
    handle_timeout_termination=False,
)

seen_obs_and_next = set()
for i in range(1000):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

    # record every (obs, next_obs) pair actually passed to the buffer
    for o, next_o in zip(obs, real_next_obs):  # because vectorized env
        seen_obs_and_next.add((tuple(o.tolist()), tuple(next_o.tolist())))

    obs = next_obs  # crucial: advance obs, as dqn_atari.py does after rb.add

data = rb.sample(10000)
for i in range(10000):
    o = data.observations[i]
    no = data.next_observations[i]
    assert (tuple(o.tolist()), tuple(no.tolist())) in seen_obs_and_next
```
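Run as posted (with `optimize_memory_usage=True`), the assertion should fail as soon as a sampled transition crosses a truncation boundary; MountainCar-v0 truncates after 200 steps by default, so the 1000 stored steps contain several such boundaries. With `optimize_memory_usage=False` the assertion passes.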