
Adding `Dataset` and `DataLoader`-like functionality in KerasCV

Open AdityaKane2001 opened this issue 2 years ago • 5 comments

@sayakpaul @LukeWood @ianjjohnson

Dataset loading and preprocessing in vision is generally a messy ordeal, and the additional boilerplate code for TFRecords makes it worse. A possible solution is to have classes that consolidate all loading and preprocessing logic, with another set of classes managing data augmentation. Concretely, I propose two kinds of classes: one for loading and preprocessing, and one for augmentation.

I understand that the code is a bit crude, but the idea is to provide an interface for users to consolidate all of their preprocessing and augmentation logic in one (or two) classes. We can extend this by providing frequently used APIs like TFRecordLoader and ImageDirectoryLoader.

Sample code:

import tensorflow as tf

AUTO = tf.data.AUTOTUNE


class DataLoader:
    """
    This class will _create_ a `tf.data.Dataset` instance.
    """
    def __init__(self, source):
        # `source` can be anything the user wants: a path to a directory,
        # a list of paths to TFRecords, etc.
        self.source = source

    def get_dataset(self):
        # Extract data from the source in whichever way necessary.
        # For example, logic to read TFRecords can be written here.
        raise NotImplementedError


class DataAugmenter:  # Open to suggestions for a better name :)
    """
    This class will _consume_ a `tf.data.Dataset` instance.
    """
    def __init__(self, dataset):  # tf.data.Dataset object
        self.dataset = dataset

    def augment(self, example):
        # all augmentation logic
        raise NotImplementedError

    def get_dataset(self, batch_size):
        # Augment in parallel, batch, then prefetch as the final step.
        ds = self.dataset.map(self.augment, num_parallel_calls=AUTO)
        ds = ds.batch(batch_size, num_parallel_calls=AUTO)
        ds = ds.prefetch(AUTO)
        return ds


class TFRecordLoader(DataLoader):  # one of the APIs that we can provide
    def __init__(self, list_of_tfrecord_paths, tfrecords_format):
        super().__init__(source=list_of_tfrecord_paths)
        self.tfrecords_format = tfrecords_format

    def get_dataset(self):
        files = tf.data.Dataset.list_files(self.source)
        ds = files.interleave(
            tf.data.TFRecordDataset, num_parallel_calls=AUTO, deterministic=False
        )
        # decode_example_fn omitted here for the sake of brevity
        ds = ds.map(self.decode_example_fn, num_parallel_calls=AUTO)
        ds = ds.prefetch(AUTO)
        return ds

Please see https://github.com/keras-team/keras-cv/issues/78#issuecomment-1070468749.

AdityaKane2001, Jul 22 '22 09:07

The overall structure looks good to me. Of course, we can incorporate all the tf.data-specific optimization tricks (reference), but I think this code is sufficient for kicking off the discussion.
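For context, the kind of tricks I have in mind could look roughly like the sketch below. It assumes a small in-memory dataset of random images; `preprocess`, `build_pipeline`, and the shapes are placeholders made up for illustration and are not part of the proposal.

```python
import tensorflow as tf

AUTO = tf.data.AUTOTUNE


def preprocess(image, label):
    # Hypothetical per-example preprocessing: scale pixels to [0, 1].
    return tf.cast(image, tf.float32) / 255.0, label


def build_pipeline(images, labels, batch_size, training=True):
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Run the deterministic preprocessing in parallel, then cache its result
    # so that it is not recomputed on every epoch.
    ds = ds.map(preprocess, num_parallel_calls=AUTO).cache()
    if training:
        # Shuffle after caching so each epoch sees a different order.
        ds = ds.shuffle(buffer_size=images.shape[0])
    ds = ds.batch(batch_size)
    # Prefetch last so input preparation overlaps with model execution.
    return ds.prefetch(AUTO)


# Random stand-in data (assumed shapes, for illustration only).
images = tf.random.uniform((32, 64, 64, 3), maxval=256, dtype=tf.int32)
labels = tf.random.uniform((32,), maxval=10, dtype=tf.int32)
train_ds = build_pipeline(images, labels, batch_size=8)
```

Caching only pays off when the cached data fits in memory (or is written to a file), and random augmentation would have to come after `cache()`, otherwise the same augmented samples are replayed every epoch.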

I would suggest you also provide a rough implementation of the TFRecord decoder utility, because those of us who are familiar with it are going to want to see it.
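Something along these lines would do as a starting point. This is only a sketch: it assumes each record stores a JPEG-encoded image under an `image` key and an integer class id under `label`. Those feature names, `FEATURE_DESCRIPTION`, and the `make_example` helper are assumptions for illustration, not a settled format.

```python
import tensorflow as tf

# Assumed record layout (hypothetical): JPEG bytes plus an integer label.
FEATURE_DESCRIPTION = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}


def decode_example_fn(serialized_example):
    # Parse one serialized tf.train.Example into a dict of tensors.
    parsed = tf.io.parse_single_example(serialized_example, FEATURE_DESCRIPTION)
    # Decode the JPEG bytes into a uint8 tensor of shape (H, W, 3).
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    return image, parsed["label"]


def make_example(image, label):
    # Helper for writing records in the layout assumed above.
    feature = {
        "image": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(image).numpy()])
        ),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Round trip on a dummy image, without touching the filesystem.
dummy = tf.zeros((32, 32, 3), dtype=tf.uint8)
serialized = make_example(dummy, label=7).SerializeToString()
image, label = decode_example_fn(serialized)
```

A function of this shape is what `TFRecordLoader.get_dataset` would pass to `ds.map`; a real version would presumably build the feature description from `tfrecords_format` instead of hard-coding it.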

sayakpaul, Jul 22 '22 09:07

To be transparent: we are still working through the details of what we want to do with respect to data loading. I appreciate your design work here, and will definitely take it into account in any design work on our side.

LukeWood, Jul 22 '22 17:07

@LukeWood great, thanks! Please let me know when these details are internally finalized, I'll be happy to contribute anything pertaining to this.

AdityaKane2001, Jul 22 '22 18:07

+1 same here

sayakpaul, Jul 22 '22 20:07

Cool, I’m beginning to flesh out the details of what I think we want to offer.

LukeWood, Jul 22 '22 21:07

For now let’s hold off on this and not expand our API too much. We have tf.data. We can do some loading like keras.datasets does but let’s keep it minimal!
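To make "minimal" concrete, here is a rough sketch. `keras.datasets` loaders return plain NumPy arrays, and going from there to `tf.data` takes only a few lines. `to_dataset` is a name made up for this example, and the random arrays stand in for the output of something like `tf.keras.datasets.cifar10.load_data()` so that nothing needs to be downloaded.

```python
import numpy as np
import tensorflow as tf


def to_dataset(images, labels, batch_size):
    # Hypothetical helper: wrap in-memory arrays in a batched tf.data.Dataset.
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Stand-in for: (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = np.random.randint(0, 256, size=(20, 32, 32, 3), dtype=np.uint8)
y_train = np.random.randint(0, 10, size=(20, 1), dtype=np.uint8)

train_ds = to_dataset(x_train, y_train, batch_size=8)
```

Anything beyond this (sharding, decoding, augmentation) can stay in user code on top of `tf.data`.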

LukeWood, Sep 22 '22 21:09

Makes sense 👍

AdityaKane2001, Sep 22 '22 21:09