keras
[Contributors Wanted] A `split_dataset` utility
It would be neat to have a utility to split datasets, somewhat similar to sklearn's `train_test_split` utility.
Example:
train_ds, val_ds = keras.utils.split_dataset(dataset, left_size=0.8, right_size=0.2)
Draft docstring (missing code examples, etc.):

def split_dataset(dataset, left_size=None, right_size=None, shuffle=False, seed=None):
    """Split a dataset into a left half and a right half (e.g. training / validation).

    Args:
        dataset: A `tf.data.Dataset` object, or a list/tuple of arrays with the same length.
        left_size: If float, it should be in the range `[0, 1]` and signifies the fraction
            of the data to pack in the left dataset. If integer, it signifies the number
            of samples to pack in the left dataset. If `None`, it defaults to the
            complement of `right_size`.
        right_size: If float, it should be in the range `[0, 1]` and signifies the fraction
            of the data to pack in the right dataset. If integer, it signifies the number
            of samples to pack in the right dataset. If `None`, it defaults to the
            complement of `left_size`.
        shuffle: Boolean, whether to shuffle the data before splitting it.
        seed: A random seed for shuffling.

    Returns:
        A tuple of two `tf.data.Dataset` objects: the left and right splits.
    """
Notes:
- When processing a Dataset, the utility would first iterate over the dataset, put the samples in a list, then split the list and create two datasets from each side of the split list. If iterating over the dataset takes more than 10s (computed continuously while iterating), a warning should be printed that the utility is only meant for small datasets that fit in memory.
- Shuffling is done optionally before splitting (on the list / arrays). Not sure if we should apply shuffle() to the returned datasets.
- Prefetching should be auto-tuned on the returned datasets.
- At least one of left_size, right_size should be specified.
- If both are specified, we should check that they are complementary; if not, that's an error.
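To make the sizing rules above concrete, here is a minimal pure-Python sketch of the size-resolution and splitting logic (my own illustration, not the actual implementation; it operates on a plain list and omits the tf.data packaging, the 10s iteration warning, and prefetch tuning):

```python
import random

def split_samples(samples, left_size=None, right_size=None, shuffle=False, seed=None):
    """Sketch of the size-resolution and splitting rules, on a plain list."""
    n = len(samples)
    if left_size is None and right_size is None:
        raise ValueError("At least one of left_size / right_size must be set.")

    def to_count(size):
        # A float is a fraction of the data; an int is an absolute sample count.
        if isinstance(size, float):
            if not 0.0 <= size <= 1.0:
                raise ValueError("Float sizes must be in the range [0, 1].")
            return int(size * n)
        return size

    if left_size is not None and right_size is not None:
        # Both given: check complementarity in matching units
        # (mixed int/float pairs are omitted from this sketch).
        expected = 1.0 if isinstance(left_size, float) else n
        if abs(left_size + right_size - expected) > 1e-9:
            raise ValueError("left_size and right_size are not complementary.")

    # If only one side is given, the other is its complement.
    left = to_count(left_size) if left_size is not None else n - to_count(right_size)

    if shuffle:
        samples = list(samples)
        random.Random(seed).shuffle(samples)  # shuffle before splitting
    return samples[:left], samples[left:]
```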
- Feel free to suggest changes / additions to the API!
Interested in contributing it? Please open a PR or comment here for questions / suggestions!
Hello. I can try sending draft code in some time. Maybe a Colab notebook, and we can discuss from there?
I have a question: why should we iterate over the dataset? Why can't we just use left = dataset.take(len(dataset) * left_size) and right = dataset.skip(len(dataset) * left_size)? I implemented a draft version of it and it seems to be working fine.
Did I misunderstand something? @fchollet
I have a question: why should we iterate over the dataset? Why can't we just use
Because this is inefficient. When you do dataset.skip() you do not create a new dataset that only contains the samples you're looking for, you create a view of the first dataset that will have to iterate over the first N samples every time you run through it. It's better to extract the samples to memory, then repackage them in a new dataset.
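A pure-Python analogy of the point above (my own sketch, not tf.data internals): a skip-style view re-reads the skipped prefix on every pass, while extracting the samples to memory once pays the full iteration cost only once.

```python
class CountingSource:
    """Yields 0..n-1 and counts how many elements are ever produced."""
    def __init__(self, n):
        self.n = n
        self.reads = 0
    def __iter__(self):
        for i in range(self.n):
            self.reads += 1
            yield i

def skip_view(source, k):
    # Like dataset.skip(k): lazily discards the first k elements on each pass.
    def gen():
        it = iter(source)
        for _ in range(k):
            next(it)
        yield from it
    return gen

source = CountingSource(100)
view = skip_view(source, 80)
for _ in range(3):              # three epochs over the "right" split
    list(view())
print(source.reads)             # 300: the skipped 80 samples are re-read every epoch

# Materializing instead: one full read, then slicing is free.
cached_source = CountingSource(100)
cache = list(cached_source)
right_split = cache[80:]
print(cached_source.reads)      # 100, no matter how many epochs follow
```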
So do you want something like this?
https://www.tensorflow.org/datasets/splits
No, that's not what we're looking for. Thanks for the pointer, though.
Yes, I've used something like that. It works with TFDS rather than tf.data as you want here, but I think the performance point is still valid:
https://github.com/tensorflow/tensorflow/issues/44008#issuecomment-719740026
Because this is inefficient. When you do dataset.skip() you do not create a new dataset that only contains the samples you're looking for, you create a view of the first dataset
Is this not going to allocate a new dataset object?
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/data/skip_dataset_op.cc#L229:L233
Is this closed as per #16398?
This can be closed now, as the tf.keras.utils.split_dataset API is implemented and available.
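For reference, a quick check of the released API (assuming TensorFlow >= 2.10, where tf.keras.utils.split_dataset shipped; it accepts a NumPy array as well as a tf.data.Dataset):

```python
import numpy as np
import tensorflow as tf

data = np.arange(100).reshape(50, 2)  # 50 samples of shape (2,)
left, right = tf.keras.utils.split_dataset(data, left_size=0.8, seed=1337)

print(left.cardinality().numpy(), right.cardinality().numpy())  # 40 10
```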