Meta-Dataset in TFDS: Getting as_numpy_iterator() from the dataset returned by api.meta_dataset takes a very long time
I am trying to use the new Meta-Dataset in TFDS API, and I have hit a critical performance problem. When I run the sample code to "Train on Meta-Dataset episodes" (with some lines added to record timings), it takes about 3 to 4 minutes to create the dataset with api.meta_dataset, and roughly 1 hour to create the iterator with episode_dataset.take(4).as_numpy_iterator(). Here is the code I am running:
import gin
import meta_dataset
from meta_dataset.data.tfds import api
import tensorflow_datasets as tfds
import time
# Set up a TFDS-compatible data configuration.
gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                      'learn/gin/setups/data_config_tfds.gin')
# The version string selects the Meta-Dataset protocol: 'v1' is the protocol
# defined in the original Meta-Dataset paper, and 'v2' is the protocol
# defined in the VTAB+MD paper (see that paper for a detailed explanation).
# This is not to be confused with the (unrelated) arXiv version of the
# Meta-Dataset paper.
md_version = 'v2'
md_sources = ('aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012', 'omniglot', 'quickdraw')
if md_version == 'v1':
    md_sources += ('vgg_flower',)
print("\nStarting TFDS reader.")
start_time = time.time()
episode_dataset = api.meta_dataset(
    md_sources,
    md_version,
    # This is where the meta-split ('train', 'valid', or 'test') is specified.
    'train',
    data_dir='<path to tensorflow datasets>'
)
t1 = time.time() # it takes about 4 minutes to get here
print("Created training dataset. Time = {0:.1f} seconds".format(t1 - start_time))
# We sample 4 episodes here for demonstration. `source_id` is the index (in
# `md_sources`) of the source that was sampled for the episode.
for episode, source_id in episode_dataset.take(4).as_numpy_iterator():
    stop_time = time.time()  # it takes about an hour to get here
    print("Created training iterator. Time = {0:.1f} seconds".format(stop_time - t1))
    print("Total reader initialization time = {0:.1f} seconds".format(stop_time - start_time), flush=True)
    support_images, support_labels, _ = episode[:3]
    query_images, query_labels, _ = episode[3:]
I have run this on Linux and on Windows with similar results. The time seems to be spent in:
_result = pywrap_tfe.TFE_Py_FastPathExecute(_ctx, "MakeIterator", name, dataset, iterator)
in gen_dataset_ops.py, which drops into C++ code that I didn't debug into.
Note that creating an iterator on api.episode_dataset for evaluation is reasonably quick: omniglot takes the longest at about 3 minutes, but the others take only a few seconds (see the sketch below for how I create these).
This issue makes training on MDv2 from TFDS more or less impossible.
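For reference, here is roughly how I create and time the per-source evaluation iterators (a minimal sketch; the exact argument order of api.episode_dataset is my assumption here, mirroring api.meta_dataset above):

import time
from meta_dataset.data.tfds import api

# Sketch of per-source evaluation iterator creation and timing. NOTE: the
# argument order of api.episode_dataset is an assumption, mirroring
# api.meta_dataset; md_version is set as in the snippet above.
for source in ('omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw',
               'fungi', 'mscoco'):
    t0 = time.time()
    eval_dataset = api.episode_dataset(source, md_version, 'valid',
                                       data_dir='<path to tensorflow datasets>')
    eval_iterator = eval_dataset.take(1).as_numpy_iterator()
    print("Created validation iterators {0}: Time = {1:.1f} seconds".format(
        source, time.time() - t0))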
I believe most of that time is spent reading data and filling shuffle buffers. Each training class in each training source is instantiated as its own dataset with its own shuffle buffer (this is how examples are sampled from specific classes to form episodes), and by default in learn/gin/setups/data_config_tfds.gin the shuffle buffer size is upper-bounded by 1000.
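To make the cost concrete, here is a minimal sketch (not the actual Meta-Dataset pipeline) of the per-class shuffle-buffer pattern; the class and example counts are made up for illustration:

import time
import tensorflow as tf

# Minimal sketch (not the actual Meta-Dataset pipeline): one shuffled
# dataset per class, mixed together the way episodic sampling does.
num_classes = 50     # MD training sources have many more classes than this
buffer_size = 1000   # the default upper bound mentioned above

def class_dataset(class_id):
    # Stand-in for reading one class's examples from disk.
    return (tf.data.Dataset.range(10_000)
            .map(lambda x: (x, class_id))
            .shuffle(buffer_size)
            .repeat())

per_class = [class_dataset(c) for c in range(num_classes)]
# On older TF versions, use tf.data.experimental.sample_from_datasets.
mixed = tf.data.Dataset.sample_from_datasets(per_class)

# Each class's buffer fills the first time that class is sampled, so the
# warm-up cost grows with num_classes * buffer_size reads (disk reads in
# the real pipeline, in-memory here).
start = time.time()
it = iter(mixed)
for _ in range(10 * num_classes):  # enough draws to touch most classes
    next(it)
print("Warm-up time: {0:.1f} seconds".format(time.time() - start))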
What happens if you set DataConfig.shuffle_buffer_size = 10 and DataConfig.num_prefetch = 1 in the configuration file? Does it speed up iterator creation? For training you could strike a balance between shuffling quality and startup time by changing the default values. For evaluation I would advise against changing the default values, as they could impact the data distribution of sampled episodes and introduce an unwanted confounding factor when comparing against competing approaches.
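If editing the gin file is inconvenient, the same overrides can be applied programmatically; a sketch using gin.parse_config (parse these bindings after the config file but before calling api.meta_dataset, since later bindings take precedence):

import gin

# Override the defaults from data_config_tfds.gin without editing the file.
gin.parse_config("""
DataConfig.shuffle_buffer_size = 10
DataConfig.num_prefetch = 1
""")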
Thanks for looking into this. I made the change that you suggested (DataConfig.shuffle_buffer_size = 10 and DataConfig.num_prefetch = 1), and it did not help, other than making creation of the training dataset somewhat faster. Here are the timings that I get:
Starting TFDS reader.
Created training dataset. Time = 143.3 seconds (got a bit faster)
Created training iterator. Time = 3347.3 seconds (this takes close to 1 hour!)
Created validation iterators omniglot: Time = 5.7 seconds
Created validation iterators aircraft: Time = 0.9 seconds
Created validation iterators cu_birds: Time = 1.7 seconds
Created validation iterators dtd: Time = 0.5 seconds
Created validation iterators quickdraw: Time = 3.4 seconds
Created validation iterators fungi: Time = 20.2 seconds
Created validation iterators mscoco: Time = 2.7 seconds
Total validation iterator creation time. Time = 35.1 seconds
Created test iterator: omniglot, Time = 155.0 seconds
Created test iterator: aircraft, Time = 0.9 seconds
Created test iterator: cu_birds, Time = 1.7 seconds
Created test iterator: dtd, Time = 0.5 seconds
Created test iterator: quickdraw, Time = 3.4 seconds
Created test iterator: fungi, Time = 22.5 seconds
Created test iterator: traffic_sign, Time = 2.6 seconds
Created test iterator: mscoco, Time = 2.3 seconds
Total test iterator creation time = 189.0 seconds
Total reader initialization time = 3714.6 seconds
Thus creating an iterator over a single dataset is acceptably fast, but creating an iterator over multiple datasets (as you need to do to meta-train on MDv2) is unacceptably slow. As I mentioned above, the time seems to be spent in:
_result = pywrap_tfe.TFE_Py_FastPathExecute(_ctx, "MakeIterator", name, dataset, iterator)
which is tricky to debug into.
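One way I could try to narrow this down without stepping into the C++ is tracing iterator creation with the TensorFlow profiler and inspecting the result in TensorBoard's Profile tab (a sketch, reusing episode_dataset from the snippet above):

import tensorflow as tf

# Capture a trace around iterator creation; the MakeIterator cost should
# show up in the trace viewer in TensorBoard's Profile tab.
tf.profiler.experimental.start('/tmp/md_iterator_trace')
it = episode_dataset.take(1).as_numpy_iterator()
_ = next(it)
tf.profiler.experimental.stop()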
This is a major blocker for us. If I can help debug in any way, I would be happy to.
John