`flush_freq` in `data_collect_wrapper.py` affecting deterministic episode playback
I've noticed some divergence between the recorded and played-back states that correlates with the `flush_freq` parameter of `DataCollectionWrapper` (I have also looked through the related Issues/PRs [1, 2]). Below is a script adapted from `robosuite/demos/demo_collect_and_playback_data.py`, in which I collect demonstrations, play back the recorded actions, and compare the L2 distance between the recorded and actual states.
In my tests, the assertion fails on the first state from the second `state_*.npz` file, i.e. the file containing the next batch of `flush_freq` states. Specifically, the assertion fails for the `state_*.npz` files created when `2 * flush_freq <= args.timesteps`. In other words, in the example below (with `--timesteps 1000`), the assert fails if I set `flush_freq=500`, but not when `flush_freq=501`.
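To make the 500 vs. 501 example concrete, here is a quick sketch of how I picture the batching before the full script, assuming the wrapper starts a new `state_*.npz` file every `flush_freq` collected steps (that is my reading of the behavior, not something I verified against the wrapper source; `batch_boundaries` is just an illustrative helper):

```python
# Hypothetical illustration of how collected steps would split into
# state_*.npz batches, assuming one flush every `flush_freq` steps.
# This is NOT robosuite code, just my mental model of the batching.
def batch_boundaries(timesteps: int, flush_freq: int):
    """Return (start, end) step ranges that would land in each flushed file."""
    return [(start, min(start + flush_freq, timesteps))
            for start in range(0, timesteps, flush_freq)]

print(batch_boundaries(1000, 500))  # [(0, 500), (500, 1000)]
print(batch_boundaries(1000, 501))  # [(0, 501), (501, 1000)]
```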
```python
import argparse
import os
import random
from glob import glob

import numpy as np

import robosuite as suite
from robosuite.wrappers import DataCollectionWrapper


def set_seed(seed: int):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)


def collect_random_trajectory(env, timesteps=1000):
    """Run a random policy to collect trajectories."""
    set_seed(0)
    env.reset()
    dof = env.action_dim
    for t in range(timesteps):
        action = np.random.randn(dof)
        env.step(action)


def playback_trajectory(env, ep_dir):
    """Play back data from an episode."""
    # first reload the model from the xml
    xml_path = os.path.join(ep_dir, "model.xml")
    with open(xml_path, "r") as f:
        xml = env.edit_model_xml(f.read())
    env.reset_from_xml_string(xml)
    env.sim.reset()

    state_paths = os.path.join(ep_dir, "state_*.npz")

    # read the states back and load them one by one
    t = 0
    set_seed(0)
    for state_file in sorted(glob(state_paths)):
        dic = np.load(state_file, allow_pickle=True)
        states = dic["states"]
        init_state = states[0]
        env.sim.set_state_from_flattened(init_state)
        env.sim.forward()

        actions = dic["action_infos"]
        for idx, act_info in enumerate(actions):
            recorded_state = states[idx]
            actual_state = env.sim.get_state().flatten()
            divergence = np.linalg.norm(recorded_state - actual_state)
            assert divergence < 1e-6, f"Divergence: {divergence} at step {t}"
            act = act_info["actions"]
            env.step(act)
            t += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--environment", type=str, default="Lift")
    parser.add_argument(
        "--robots",
        nargs="+",
        type=str,
        default="Panda",
        help="Which robot(s) to use in the env",
    )
    parser.add_argument("--directory", type=str, default="tmp/")
    parser.add_argument("--timesteps", type=int, default=1000)
    args = parser.parse_args()

    # create original environment
    env = suite.make(
        args.environment,
        robots=args.robots,
        ignore_done=True,
        use_camera_obs=False,
        has_renderer=False,
        has_offscreen_renderer=False,
        control_freq=20,
    )

    data_directory = args.directory

    # wrap the environment with data collection wrapper
    env = DataCollectionWrapper(env, data_directory, flush_freq=200)

    # collect some data
    print("Collecting some random data...")
    collect_random_trajectory(env, timesteps=args.timesteps)

    print("Playing back the data...")
    data_directory = env.ep_directory
    playback_trajectory(env, data_directory)
```
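For completeness, a minimal sketch for inspecting the flushed files directly, without stepping the simulator. `inspect_episode` is a hypothetical helper of mine (not part of robosuite); it relies only on the `states` / `action_infos` keys already used above, and the boundary-gap print is just to show how consecutive batches line up.

```python
# Standalone inspection of the flushed episode files (no simulation required).
import os
from glob import glob

import numpy as np


def inspect_episode(ep_dir):
    """Print per-file state/action counts and the L2 gap at each flush boundary."""
    prev_last_state = None
    for path in sorted(glob(os.path.join(ep_dir, "state_*.npz"))):
        dic = np.load(path, allow_pickle=True)
        states = dic["states"]
        n_actions = len(dic["action_infos"])
        print(f"{os.path.basename(path)}: {len(states)} states, {n_actions} action_infos")
        if prev_last_state is not None:
            # distance between the last state of the previous batch and the
            # first state of this batch, to see how the batches line up
            print("  boundary gap:", np.linalg.norm(prev_last_state - states[0]))
        prev_last_state = states[-1]
```

Since it only reads the saved `.npz` files, it can be pointed at the directory reported as `env.ep_directory` right after collection.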