hybridbackend 0.6.0a2 raises ValueError when ParquetDataset is wrapped by parallel_interleave
Current behavior
When hb.data.ParquetDataset is wrapped by tf.data.experimental.parallel_interleave, a ValueError is raised: Field xxx (dtype=unkown
Expected behavior
hb.data.ParquetDataset wrapped by tf.data.experimental.parallel_interleave should work the same way it does in hybridbackend 0.6.0a1.
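For comparison, here is a minimal sketch of the unwrapped read path (reading the same file with hb.data.ParquetDataset directly, reusing the file name and field definitions from the repro below); this is only a baseline sketch, not part of the failing pipeline:

import tensorflow as tf
import hybridbackend.tensorflow as hb

# Minimal sketch: read the parquet file directly, without parallel_interleave.
ds = hb.data.ParquetDataset(
    'part-00000-d07256ce-4685-4d6c-a9ab-b507ffef206e-c000.snappy.parquet',
    batch_size=4,
    num_parallel_reads=1,
    fields=[
        hb.data.DataFrame.Field('uid', tf.int64),
        hb.data.DataFrame.Field('packagename', tf.int64, ragged_rank=0),
        hb.data.DataFrame.Field('recent_play_3', tf.int64, ragged_rank=1),
        hb.data.DataFrame.Field('label_play', tf.float64),
    ],
)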
System information
- GPU model and memory: Tesla T4 16G
- OS Platform: Ubuntu 18.04.5 LTS
- Docker version: 20.10.14
- GCC/CUDA/cuDNN version: gcc version 7.5.0 (Ubuntu 7.5.0-3ubuntu1~18.04)/CUDA Version: 11.4.2/cuDNN 8
- Python/conda version:
- TensorFlow/PyTorch version: 1.15.5+deeprec2201
- HybridBackend version: '0.6.0a2'
Code to reproduce
import tensorflow as tf
import hybridbackend.tensorflow as hb
from tensorflow.python.data.ops import dataset_ops
def make_initializable_iterator(ds):
  r"""Wrapper of make_initializable_iterator.
  """
  if hasattr(dataset_ops, 'make_initializable_iterator'):
    return dataset_ops.make_initializable_iterator(ds)
  return ds.make_initializable_iterator()

def parquet_map(record):
  # Splits the label column off from the features (not applied in this repro).
  label = record.pop('label_play')
  return record, label
# Read from a parquet file.
dataset = tf.data.Dataset.list_files(
    ['part-00000-d07256ce-4685-4d6c-a9ab-b507ffef206e-c000.snappy.parquet'],
    seed=1)
dataset = dataset.apply(
    tf.data.experimental.parallel_interleave(
        lambda x: hb.data.ParquetDataset(
            x,
            # drop_remainder=True,
            batch_size=4,
            num_parallel_reads=1,
            fields=[
                hb.data.DataFrame.Field('uid', tf.int64),
                hb.data.DataFrame.Field('packagename', tf.int64, ragged_rank=0),
                hb.data.DataFrame.Field('recent_play_3', tf.int64, ragged_rank=1),
                hb.data.DataFrame.Field('label_play', tf.float64),
            ],
        ),
        cycle_length=1,
        block_length=1,
    ))
ds = dataset.prefetch(4)
iterator = make_initializable_iterator(ds)
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
with tf.Session(config=sess_config) as sess:
  sess.run(iterator.initializer)
  for i in range(1):
    feature = sess.run(iterator.get_next())
    print(feature)
You can download the toy dataset from here.
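If the link is unavailable, a substitute file with a similar layout can be generated with pyarrow; this is a hypothetical sketch, and the column values and list lengths are made up rather than taken from the original toy dataset:

import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical sketch: write a tiny snappy-compressed parquet file whose columns
# roughly match the fields declared above (the exact schema is an assumption).
table = pa.table({
    'uid': pa.array([1, 2, 3, 4], type=pa.int64()),
    'packagename': pa.array([101, 102, 103, 104], type=pa.int64()),
    'recent_play_3': pa.array([[1, 2, 3], [4], [], [5, 6]], type=pa.list_(pa.int64())),
    'label_play': pa.array([0.0, 1.0, 0.0, 1.0], type=pa.float64()),
})
pq.write_table(table, 'toy.snappy.parquet', compression='snappy')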
Willing to contribute
Yes