
Truncation not handled correctly when `optimize_memory_usage=True`

samlobel opened this issue 9 months ago

Problem Description

In dqn_atari.py, for example, the replay buffer is instantiated with the `optimize_memory_usage=True` flag. With this flag the buffer keeps only a single array of observations and reconstructs `next_obs = observations[i + 1]` when sampling. CleanRL has its own logic to handle truncation before adding a transition (`if trunc: real_next_obs[idx] = infos["final_observation"][idx]`), but with `optimize_memory_usage=True` that correction is not reflected in the stored/sampled data.
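
To see why, here is a simplified sketch of the memory-optimized buffer's bookkeeping (a paraphrase for illustration, not SB3's actual code): obs and next_obs share a single array, and the next observation of transition `i` is read back from slot `i + 1`.

```python
import numpy as np


class TinyOptimizedBuffer:
    """Paraphrase of a memory-optimized replay buffer (illustrative only)."""

    def __init__(self, size, obs_dim):
        self.observations = np.zeros((size, obs_dim))
        self.size = size
        self.pos = 0

    def add(self, obs, next_obs):
        self.observations[self.pos] = obs
        # next_obs is written into slot pos + 1 ...
        self.observations[(self.pos + 1) % self.size] = next_obs
        self.pos = (self.pos + 1) % self.size
        # ... but the *next* add() overwrites that same slot with its own
        # obs. After a truncation that obs is the post-reset observation,
        # so the patched-in `final_observation` is silently lost.

    def get(self, i):
        # next_obs is reconstructed from the following slot
        return self.observations[i], self.observations[(i + 1) % self.size]
```

So `rb.add(obs, real_next_obs, ...)` does write the corrected observation, but the very next `rb.add(...)` clobbers it with the post-reset observation.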

Current Behavior

Instead of `data.next_observations[i]` being the correct next observation, when an episode is truncated the sampled next observation is the first observation of the freshly reset environment.
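
Schematically (hypothetical slot index `i`, truncation at step `T`), the shared array ends up holding:

```python
# contents of the shared observation array around a truncation (schematic):
#   observations[i]     = obs_T      # last observation before truncation
#   observations[i + 1] = obs_reset  # reset obs overwrote final_observation
#
# sampling transition i therefore yields
#   (data.observations[i], data.next_observations[i]) == (obs_T, obs_reset)
# instead of the intended (obs_T, final_observation)
```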

Expected Behavior

It should be the final observation of the truncated episode, i.e. the value patched in from `infos["final_observation"]`.

Possible Solution

I'm guessing there's a way to make this work within the memory-optimized layout, but for now the easiest fix is to set `optimize_memory_usage=False`.
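
Concretely, the workaround in dqn_atari.py's buffer setup would look like this (argument names as in that script):

```python
rb = ReplayBuffer(
    args.buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    optimize_memory_usage=False,  # workaround: store next_obs in its own array
    handle_timeout_termination=False,
)
```

Note this doubles the observation storage, which for Atari frame stacks is substantial and is presumably why the flag was enabled in the first place.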

Steps to Reproduce

Here's a minimal code example, where the important parts are cribbed directly from dqn_atari.py. Switching to `optimize_memory_usage=False` prevents the assertion error.

```python
import gymnasium as gym

from stable_baselines3.common.buffers import ReplayBuffer
import stable_baselines3 as sb3
import numpy as np


def make_env(env_id, seed, idx, capture_video, run_name):    
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv(
    [make_env("MountainCar-v0", i, i, False, "testing") for i in [0]]
)

obs, _ = envs.reset(seed=0)

rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=True,
    # optimize_memory_usage=False,
    handle_timeout_termination=False,
)

seen_obs_and_next = set()
for i in range(1000):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]

    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
    for o, next_o in zip(obs, real_next_obs):  # because vectorized env
        seen_obs_and_next.add((tuple(o.tolist()), tuple(next_o.tolist())))
    obs = next_obs  # crucial: advance the observation, as dqn_atari.py does


data = rb.sample(10000)
for i in range(10000):
    o = data.observations[i]
    no = data.next_observations[i]
    assert (tuple(o.tolist()), tuple(no.tolist())) in seen_obs_and_next
```


samlobel · Apr 25 '24