envlogger
envlogger copied to clipboard
dataset with TFDS feature Sequence will crash when trying to load
dataset generation code:
from envlogger.backends.tfds_backend_writer import *
from envlogger.step_data import *
import numpy as np
import dm_env
import tensorflow as tf
from os.path import expanduser
NB_FRAME = 3
SHAPE = (3, 3, 3)
HOME = expanduser("~")
DIRECTORY = HOME + "/data/test"
"""
Composite FeatureConnector for a dict where each value is a list.
"""
# a sequence of images, NB_FRAME long
tfds_features = tfds.features.Sequence(tfds.features.Image(shape=SHAPE), length=NB_FRAME)
observation = np.zeros((NB_FRAME, ) + SHAPE, dtype="uint8")
ds_config = tfds.rlds.rlds_base.DatasetConfig(
name='test',
observation_info=tfds_features,
action_info=tf.float64,
reward_info=tf.float64,
discount_info=tf.float64 # default python type for 0.
)
writer = TFDSBackendWriter(data_directory=DIRECTORY,
split_name='train', # required
max_episodes_per_file=500,
ds_config=ds_config)
zero_float64 = 0.0 # np.array(0.0, dtype="float64")
# start episode
timestep = dm_env.restart(observation=observation)
data = StepData(timestep, zero_float64)
writer.record_step(data, True)
# transition episode
timestep = dm_env.transition(reward=zero_float64, observation=observation)
data = StepData(timestep, zero_float64)
writer.record_step(data, False)
# end episode
timestep = dm_env.termination(reward=zero_float64, observation=observation)
data = StepData(timestep, zero_float64)
writer.record_step(data, False)
# close
writer.close()
dataset reader code:
from os.path import expanduser
import tensorflow_datasets as tfds
import rlds
""" Parameters """
HOME = expanduser("~")
DIRECTORY = HOME + "/data/test"
# load the dataset
builder = tfds.builder_from_directory(DIRECTORY)
dataset = builder.as_dataset(split='all')
print("Nb episode: ", len(dataset))
# flatten dataset
dataset = dataset.flat_map(lambda episode: episode[rlds.STEPS])
nb_steps = rlds.transformations.episode_length(dataset).numpy()
print("Nb steps: ", nb_steps)
Error generated:
Exception has occurred: TypeError
Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'ragged_flat_values'
File "/home/omnid/dexnex/ws_dexnex/src/ros2-to-rlds/ros2-to-rlds/test/test_ds_load.py", line 12, in <module>
dataset = builder.as_dataset(split='all')
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'ragged_flat_values'
@sabelaraga
any thoughts?
I think this is related to a problem in TFDS https://github.com/tensorflow/datasets/issues/2243
What's kind of crazy is how google/dm want the robotics community to use open-x-embodiment, RLDS, and envlogger, and then don't even release a functioning tool which is compatible with the types of data that roboticists would want to capture