
[RL Unplugged] How to recover the sequential data

Open IcarusWizard opened this issue 4 years ago • 3 comments

Hi, I am interested in benchmarking my algorithm on the dm_control dataset, but I am stuck on the data-loading part.

Since dm_control is a set of sequential control tasks, and I have seen that the CRR algorithm in dm_control_suite_crr.ipynb uses an LSTM, I assumed the dataset must store the time order of the data. But when I try the demos in the repository, the dataset seems to be shuffled and I only get tuples like (s, a, r, s', a').

I am not familiar with tf.data.Dataset, so could anyone give me a hint on how to recover the sequential structure of the data? Or is the sequential structure not stored in the dataset at all? Thanks in advance!

IcarusWizard, Nov 04 '20 08:11

I would also like to know this. It seems likely that the data is not stored chronologically, though, which is a pain.
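One way to test whether the records are stored in time order is to check whether consecutive records chain together: if record i holds steps (t-1, t), the next record should start at step t. Below is a minimal NumPy sketch of that check. It assumes the records have already been pulled out of TensorFlow (for example with `Dataset.as_numpy_iterator()`) and that each record's observation dict has been concatenated into a single array of shape [2, obs_dim]; `fraction_chained` is a hypothetical helper, not part of RL Unplugged.

```python
import numpy as np


def fraction_chained(records, atol=1e-6):
    """Fraction of consecutive record pairs that chain in time.

    Each record is assumed to be a dict whose 'observation' entry is an
    array of shape [2, obs_dim] holding steps t-1 and t. A pair chains if
    the next record's first observation equals this record's second one.
    """
    if len(records) < 2:
        return float('nan')
    hits = 0
    for prev, nxt in zip(records[:-1], records[1:]):
        if np.allclose(prev['observation'][1], nxt['observation'][0], atol=atol):
            hits += 1
    return hits / (len(records) - 1)


# Synthetic trajectory: 5 steps with a 2-dimensional observation each.
traj = np.arange(10, dtype=np.float32).reshape(5, 2)
ordered = [{'observation': traj[i:i + 2]} for i in range(4)]
print(fraction_chained(ordered))                   # 1.0
print(fraction_chained(list(reversed(ordered))))   # 0.0
```

On real data, a value close to 1.0 would indicate chronological storage, while a value close to 0.0 is consistent with shuffled records.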

indrasweb, Nov 06 '20 11:11

Here's something that will help.

```python
import functools
import os

import reverb
import tensorflow as tf
import tree  # dm-tree

# Task definitions (data paths, feature shapes) from deepmind-research's
# rl_unplugged package; the repository root must be on the Python path.
from rl_unplugged import dm_control_suite


class DataLoader:
    """Reads RL Unplugged dm_control suite shards without shuffling."""

    def __init__(self, root: str, task: str, shards: int):
        self._root = root      # directory containing the downloaded shards
        self._task = task      # task name, e.g. 'cartpole_swingup'
        self._shards = shards  # number of shard files for this task

    def load_data(self):
        def _build_sarsa_example(sequences):
            """Turn a two-step record into (o_tm1, a_tm1, r_t, p_t, o_t, a_t)."""
            o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])
            o_t = tree.map_structure(lambda t: t[1], sequences['observation'])
            a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])
            a_t = tree.map_structure(lambda t: t[1], sequences['action'])
            r_t = tree.map_structure(lambda t: t[0], sequences['reward'])
            p_t = tree.map_structure(lambda t: t[0], sequences['discount'])

            # Dummy metadata so the output looks like a Reverb replay sample.
            # Depending on the Reverb version, SampleInfo may need more fields.
            info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
                                     probability=tf.constant(1.0, tf.float64),
                                     table_size=tf.constant(0, tf.int64),
                                     priority=tf.constant(1.0, tf.float64))
            return reverb.ReplaySample(info=info, data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t))

        def _parse_seq_tf_example(example, uint8_features, shapes):
            """Parse tf.Example containing one or two episode steps."""
            def to_feature(key, shape):
                # uint8 (e.g. pixel) features are stored as raw bytes.
                if key in uint8_features:
                    return tf.io.FixedLenSequenceFeature(
                        shape=[], dtype=tf.string, allow_missing=True)
                else:
                    return tf.io.FixedLenSequenceFeature(
                        shape=shape, dtype=tf.float32, allow_missing=True)

            feature_map = {k: to_feature(k, v) for k, v in shapes.items()}
            parsed = tf.io.parse_single_example(example, features=feature_map)

            # Group the 'observation/...' keys into a nested observation dict.
            observation = {}
            restructured = {}
            for k in parsed.keys():
                if 'observation' not in k:
                    restructured[k] = parsed[k]
                    continue

                if k in uint8_features:
                    observation[k.replace('observation/', '')] = tf.reshape(
                        tf.io.decode_raw(parsed[k], out_type=tf.uint8), (-1,) + shapes[k])
                else:
                    observation[k.replace('observation/', '')] = parsed[k]

            restructured['observation'] = observation
            # Number of steps in this record (1 or 2).
            restructured['length'] = tf.shape(restructured['action'])[0]
            return restructured

        task = dm_control_suite.ControlSuite(self._task)
        path = os.path.join(self._root, task.data_path)
        filenames = [f'{path}-{i:05d}-of-{self._shards:05d}' for i in range(self._shards)]
        file_ds = tf.data.Dataset.from_tensor_slices(filenames)

        # flat_map reads each shard to the end before opening the next one, so
        # records keep their on-disk order. interleave with cycle_length > 1
        # would mix records from several shards.
        example_ds = file_ds.flat_map(
            functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'))

        def map_func(example):
            # Use the task's own uint8 feature set instead of an empty one.
            return _parse_seq_tf_example(
                example, uint8_features=task.uint8_features, shapes=task.shapes)

        # Parallel map is deterministic by default, so order is preserved.
        example_ds = example_ds.map(map_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)

        example_ds = example_ds.map(
            _build_sarsa_example,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

        return example_ds
```

This will give you the dataset unshuffled.
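One caveat: even without a shuffle buffer, `interleave` with `cycle_length` greater than 1 keeps several shards open and alternates between them in blocks of `block_length` records, so consecutive outputs can come from different files. The pure-Python function below is a simplified model of that round-robin order (it ignores parallelism and is not tf.data's exact implementation); `interleave_order` is a hypothetical name. With `cycle_length=1`, which is equivalent to `flat_map`, each shard is read to the end before the next one starts.

```python
def interleave_order(files, cycle_length, block_length):
    """Simplified model of the record order produced by Dataset.interleave.

    Keeps up to `cycle_length` files open and takes up to `block_length`
    records from each in round-robin order. When an open file runs out,
    the next unopened file takes its slot.
    """
    if cycle_length < 1 or block_length < 1:
        raise ValueError('cycle_length and block_length must be >= 1')
    pending = [iter(f) for f in files]
    active = []
    while pending and len(active) < cycle_length:
        active.append(pending.pop(0))

    out = []
    slot = 0
    while active:
        exhausted = False
        for _ in range(block_length):
            try:
                out.append(next(active[slot]))
            except StopIteration:
                exhausted = True
                break
        if exhausted and pending:
            active[slot] = pending.pop(0)  # refill the slot with the next file
            slot = (slot + 1) % len(active)
        elif exhausted:
            active.pop(slot)  # nothing left to open; shrink the cycle
            if active:
                slot %= len(active)
        else:
            slot = (slot + 1) % len(active)
    return out


shard_a = ['a0', 'a1', 'a2', 'a3']
shard_b = ['b0', 'b1', 'b2', 'b3']
print(interleave_order([shard_a, shard_b], cycle_length=2, block_length=2))
# ['a0', 'a1', 'b0', 'b1', 'a2', 'a3', 'b2', 'b3']
print(interleave_order([shard_a, shard_b], cycle_length=1, block_length=5))
# ['a0', 'a1', 'a2', 'a3', 'b0', 'b1', 'b2', 'b3']
```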

If you don't want it in SARSA format, just delete these lines:

```python
example_ds = example_ds.map(
    _build_sarsa_example,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
```
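To see what the SARSA mapping does to a single record, here is a NumPy analogue of `_build_sarsa_example` that drops the `reverb.ReplaySample` wrapper; `build_sarsa_tuple` is a hypothetical name. It assumes each parsed record stacks exactly two consecutive steps along the first axis, which is what the `t[0]`/`t[1]` indexing in the snippet relies on, and it returns the same tuple order as `data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t)`.

```python
import numpy as np


def build_sarsa_tuple(record):
    """NumPy analogue of _build_sarsa_example, without the Reverb wrapper.

    Assumes every field of `record` stacks two consecutive steps on axis 0,
    and that 'observation' is a dict of such arrays.
    """
    o_tm1 = {k: v[0] for k, v in record['observation'].items()}
    o_t = {k: v[1] for k, v in record['observation'].items()}
    a_tm1 = record['action'][0]
    a_t = record['action'][1]
    # As in the snippet above, reward and discount are read from the first step.
    r_t = record['reward'][0]
    d_t = record['discount'][0]
    return o_tm1, a_tm1, r_t, d_t, o_t, a_t


record = {
    'observation': {'position': np.array([[0.0, 1.0], [2.0, 3.0]])},
    'action': np.array([[0.5], [-0.5]]),
    'reward': np.array([1.0, 0.0]),
    'discount': np.array([0.99, 1.0]),
}
o_tm1, a_tm1, r_t, d_t, o_t, a_t = build_sarsa_tuple(record)
```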

atlashugs, Dec 08 '20 10:12

I don't think recovering full episodes is possible: even within a single file, the elements were shuffled. However, you can still identify the initial and final transitions of each trajectory by checking the "step_type" field. According to https://github.com/deepmind/dm_env/blob/master/dm_env/_environment.py#L69 , 0 (FIRST) marks the start of a trajectory, 1 (MID) the middle, and 2 (LAST) the end.
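A minimal sketch of that boundary check in plain Python. The `StepType` enum below mirrors the values of `dm_env.StepType` (importing that instead would also work), and the sketch assumes each parsed record carries a 'step_type' sequence covering its one or two steps; `split_by_boundary` is a hypothetical helper.

```python
from enum import IntEnum


class StepType(IntEnum):
    """Mirrors dm_env.StepType: FIRST=0, MID=1, LAST=2."""
    FIRST = 0
    MID = 1
    LAST = 2


def split_by_boundary(records):
    """Return (initial, terminal) records based on their 'step_type' entries.

    A record counts as initial if its first step is FIRST, and as terminal
    if its last step is LAST. Float step types (0.0, 1.0, 2.0) compare equal
    to the IntEnum members, so values parsed as float32 also work.
    """
    initial = [r for r in records if r['step_type'][0] == StepType.FIRST]
    terminal = [r for r in records if r['step_type'][-1] == StepType.LAST]
    return initial, terminal


records = [
    {'step_type': [0.0, 1.0]},  # starts a trajectory
    {'step_type': [1.0, 1.0]},  # middle of a trajectory
    {'step_type': [1.0, 2.0]},  # ends a trajectory
]
initial, terminal = split_by_boundary(records)
print(len(initial), len(terminal))  # 1 1
```

In the tf.data loader posted above, the same idea could be applied before the `_build_sarsa_example` map with `example_ds.filter(lambda ex: tf.equal(ex['step_type'][0], 0.))`, assuming 'step_type' is among `task.shapes`; `_parse_seq_tf_example` parses such non-uint8 features as `float32`, hence the float comparison.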

sunshineclt, Aug 04 '21 16:08