robosuite icon indicating copy to clipboard operation
robosuite copied to clipboard

`flush_freq` in `data_collection_wrapper.py` affects deterministic episode playback

Open jren03 opened this issue 11 months ago • 0 comments

I've noticed divergence between the recorded and played-back states that correlates with the `flush_freq` parameter of `DataCollectionWrapper` (I have already reviewed the related issues/PRs [1, 2]).

Below is a script adapted from `robosuite/demos/demo_collect_and_playback_data.py`. It collects a demonstration, plays back the recorded actions, and compares the L2 distance between the recorded and actual simulator states.

In my tests, the assertion fails on the first state of the second `state_*.npz` file, i.e. the file containing the next batch of `flush_freq` states. Specifically, it fails when `2 * flush_freq <= args.timesteps`, i.e. whenever at least two full batches of states are recorded. For example, with `--timesteps 1000` the assertion fails with `flush_freq=500` but not with `flush_freq=501`. (The script below hard-codes `flush_freq=200`; edit that value to reproduce.)

import argparse
import os
import random
from glob import glob

import numpy as np

import robosuite as suite
from robosuite.wrappers import DataCollectionWrapper


def set_seed(seed: int):
    """Seed every global random number generator this script draws from.

    Python's ``random`` module is seeded first, then NumPy's legacy global
    generator (``np.random``), both with the same ``seed``.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def collect_random_trajectory(env, timesteps=1000, seed=0):
    """Run a random policy to collect a trajectory.

    The global RNGs are seeded, ``env`` is reset, and then ``timesteps``
    actions are applied. Each action is drawn from a standard normal
    distribution with one value per action dimension.

    Args:
        env: environment to step (e.g. a ``DataCollectionWrapper``) exposing
            ``reset``, ``step`` and ``action_dim``.
        timesteps (int): number of actions to apply.
        seed (int): seed passed to ``set_seed`` before the reset, so that
            repeated runs draw the same action sequence.
    """
    set_seed(seed)
    env.reset()
    dof = env.action_dim

    for _ in range(timesteps):
        env.step(np.random.randn(dof))


def playback_trajectory(env, ep_dir, atol=1e-6, restore_each_file=True):
    """Replay an episode's recorded actions and check them against the recorded states.

    The episode's ``model.xml`` is loaded into ``env``. Then every
    ``state_*.npz`` file in ``ep_dir`` is processed in sorted filename order.
    Before the action at index ``idx`` of a file is stepped, the simulator's
    flattened state is compared (L2 norm) against ``states[idx]`` of that
    same file.

    Args:
        env: environment exposing ``edit_model_xml``, ``reset_from_xml_string``,
            ``sim`` and ``step`` (e.g. the env used for collection).
        ep_dir (str): episode directory containing ``model.xml`` and the
            ``state_*.npz`` files.
        atol (float): the check fails unless the L2 distance between recorded
            and actual state is strictly below this value.
        restore_each_file (bool): if True (the original behavior), the
            simulator is set to the first recorded state of *every* file. That
            makes the index-0 check of each file trivially pass. If False, the
            state is restored only from the first file, so the transition
            across a file boundary is verified as well.

    Returns:
        int: number of recorded steps that were replayed and checked.

    Raises:
        FileNotFoundError: if ``ep_dir`` contains no ``state_*.npz`` file
            (otherwise the check would pass without verifying anything).
        AssertionError: if the divergence at some step is not below ``atol``.
    """
    # First reload the model from the xml.
    xml_path = os.path.join(ep_dir, "model.xml")
    with open(xml_path, "r") as f:
        xml = env.edit_model_xml(f.read())
    env.reset_from_xml_string(xml)
    env.sim.reset()

    state_files = sorted(glob(os.path.join(ep_dir, "state_*.npz")))
    if not state_files:
        raise FileNotFoundError(f"No state_*.npz files found in {ep_dir}")

    # The global RNG is intentionally not reseeded: playback is driven only by
    # the recorded actions, and the sim state is restored from the files.
    t = 0
    for file_idx, state_file in enumerate(state_files):
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once both arrays have been read into memory.
        with np.load(state_file, allow_pickle=True) as dic:
            states = dic["states"]
            actions = dic["action_infos"]

        if restore_each_file or file_idx == 0:
            env.sim.set_state_from_flattened(states[0])
            env.sim.forward()

        for idx, act_info in enumerate(actions):
            actual_state = env.sim.get_state().flatten()
            divergence = np.linalg.norm(states[idx] - actual_state)
            # Explicit raise instead of `assert` (stripped under `python -O`).
            # `not <` also rejects NaN, matching the original assert.
            if not divergence < atol:
                raise AssertionError(
                    f"Divergence: {divergence} at step {t} "
                    f"(index {idx} of {os.path.basename(state_file)})"
                )
            env.step(act_info["actions"])
            t += 1
    return t


def main():
    """Collect a random-policy episode, then replay it and check for state divergence.

    Command-line options select the environment, robot(s), output directory,
    number of collection timesteps, and the ``DataCollectionWrapper``
    ``flush_freq``. The wrapped environment is closed when the function exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--environment", type=str, default="Lift")
    parser.add_argument(
        "--robots",
        nargs="+",
        type=str,
        # A list, matching what nargs="+" yields when the flag is given.
        default=["Panda"],
        help="Which robot(s) to use in the env",
    )
    parser.add_argument("--directory", type=str, default="tmp/")
    parser.add_argument("--timesteps", type=int, default=1000)
    parser.add_argument(
        "--flush_freq",
        type=int,
        default=200,
        help="flush_freq passed to DataCollectionWrapper (states per state_*.npz batch)",
    )
    args = parser.parse_args()

    # create original environment
    env = suite.make(
        args.environment,
        robots=args.robots,
        ignore_done=True,
        use_camera_obs=False,
        has_renderer=False,
        has_offscreen_renderer=False,
        control_freq=20,
    )

    # wrap the environment with data collection wrapper
    env = DataCollectionWrapper(env, args.directory, flush_freq=args.flush_freq)

    try:
        # collect some data
        print("Collecting some random data...")
        collect_random_trajectory(env, timesteps=args.timesteps)

        # NOTE(review): presumably only batches already written to disk are
        # replayed; steps not yet flushed when playback starts are not
        # checked. Also, `env` is still the wrapper, so playback steps go
        # through its `step` as well — confirm whether they are recorded
        # into the same episode directory.
        print("Playing back the data...")
        playback_trajectory(env, env.ep_directory)
    finally:
        env.close()


if __name__ == "__main__":
    main()

jren03 avatar Mar 18 '24 14:03 jren03