[RL Unplugged] How to recover the sequential data
Hi, I am interested in benchmarking my algorithm on the dm_control dataset, and I am stuck on the data loading part. Since dm_control is a set of sequential control tasks, and the CRR algorithm in dm_control_suite_crr.ipynb uses an LSTM, I assume the dataset must store the time order of the data somewhere. However, when I try the demos in the repository, the dataset seems to be shuffled and I only get tuples like (s, a, r, s', a'). I am not familiar with tf.data.Dataset, so could anyone give me a hint about how to recover the sequential structure of the data? Or is the sequential structure not stored in the dataset at all? Thanks in advance!
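For reference, this is roughly what I see when I iterate over the dataset built in the notebook (just a sketch; `ds` stands for the batched tf.data.Dataset from the demo, and I am assuming the samples come back as reverb.ReplaySample tuples in (s, a, r, discount, s', a') order, as in the Acme examples):

import tree

for sample in ds.take(1):
  o_tm1, a_tm1, r_t, d_t, o_t, a_t = sample.data  # the flat tuples I mentioned
  print(tree.map_structure(lambda t: t.shape, o_tm1))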
I would also like to know this. It seems likely that the data is not stored chronologically, though, which is a pain.
Here's something that will help.
import collections
import functools
import os

import reverb
import six
import tree
import numpy as np
import tensorflow as tf

from acme import specs
from reverb import replay_sample

# ControlSuite lives in the rl_unplugged package of this repo and provides
# the per-task data_path and feature shapes.
from rl_unplugged import dm_control_suite


class DataLoader:
  """Loads the RL Unplugged dm_control_suite TFRecords without shuffling."""

  def __init__(self, root: str, task: str, shards: int):
    self._root = root
    self._task = task
    self._shards = shards

  def load_data(self):

    def _build_sarsa_example(sequences):
      """Converts a parsed two-step example into a SARSA ReplaySample."""
      o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])
      o_t = tree.map_structure(lambda t: t[1], sequences['observation'])
      a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])
      a_t = tree.map_structure(lambda t: t[1], sequences['action'])
      r_t = tree.map_structure(lambda t: t[0], sequences['reward'])
      p_t = tree.map_structure(lambda t: t[0], sequences['discount'])
      info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
                               probability=tf.constant(1.0, tf.float64),
                               table_size=tf.constant(0, tf.int64),
                               priority=tf.constant(1.0, tf.float64))
      return reverb.ReplaySample(info=info,
                                 data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t))

    def _parse_seq_tf_example(example, uint8_features, shapes):
      """Parses a tf.Example containing one or two episode steps."""

      def to_feature(key, shape):
        if key in uint8_features:
          return tf.io.FixedLenSequenceFeature(
              shape=[], dtype=tf.string, allow_missing=True)
        else:
          return tf.io.FixedLenSequenceFeature(
              shape=shape, dtype=tf.float32, allow_missing=True)

      feature_map = {}
      for k, v in shapes.items():
        feature_map[k] = to_feature(k, v)
      parsed = tf.io.parse_single_example(example, features=feature_map)

      # Group the flat 'observation/...' keys back into a nested dict.
      observation = {}
      restructured = {}
      for k in parsed.keys():
        if 'observation' not in k:
          restructured[k] = parsed[k]
          continue
        if k in uint8_features:
          observation[k.replace('observation/', '')] = tf.reshape(
              tf.io.decode_raw(parsed[k], out_type=tf.uint8),
              (-1,) + shapes[k])
        else:
          observation[k.replace('observation/', '')] = parsed[k]
      restructured['observation'] = observation
      restructured['length'] = tf.shape(restructured['action'])[0]
      return restructured

    task = dm_control_suite.ControlSuite(self._task)
    path = os.path.join(self._root, task.data_path)
    filenames = [
        f'{path}-{i:05d}-of-{self._shards:05d}' for i in range(self._shards)
    ]
    file_ds = tf.data.Dataset.from_tensor_slices(filenames)
    # Read the shard files in order, without shuffling them.
    example_ds = file_ds.interleave(
        functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'),
        cycle_length=tf.data.experimental.AUTOTUNE,
        block_length=5)

    def map_func(example):
      return _parse_seq_tf_example(
          example, uint8_features={}, shapes=task.shapes)

    example_ds = example_ds.map(
        map_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    example_ds = example_ds.map(
        _build_sarsa_example,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return example_ds
This will give you the dataset unshuffled.
If you don't want it in SARSA format, just drop the final map call:
example_ds = example_ds.map(
    _build_sarsa_example,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
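For what it's worth, here is a hypothetical usage sketch; the root path, task name, and shard count below are placeholders you would need to match to your local copy of the data:

# Placeholder arguments: point `root` at wherever you downloaded the data,
# and set `shards` to the actual number of TFRecord shards for the task.
loader = DataLoader(root='gs://rl_unplugged/dm_control_suite',
                    task='cartpole_swingup',
                    shards=100)
ds = loader.load_data()
for sample in ds.take(2):
  o_tm1, a_tm1, r_t, p_t, o_t, a_t = sample.data
  print(r_t.numpy(), p_t.numpy())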
I don't think recovering full episodes is possible: even within a single file, the elements were shuffled. However, it is possible to extract initial/ending transitions by checking the "step_type" field: 0 indicates the start of a trajectory, 1 the middle, and 2 the end, according to https://github.com/deepmind/dm_env/blob/master/dm_env/_environment.py#L69 .
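A hedged sketch of what that could look like on the example_ds built above (before the SARSA mapping), assuming the parsed examples expose a 'step_type' entry, i.e. that it is present in task.shapes:

# step_type is parsed as float32 above; 0 = first, 1 = mid, 2 = last.
first_steps = example_ds.filter(lambda ex: tf.equal(ex['step_type'][0], 0.0))
last_steps = example_ds.filter(lambda ex: tf.equal(ex['step_type'][-1], 2.0))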