[Tensorflow] Support tf.dataset.repeat() to avoid duplicating and dropping samples in one epoch with shuffle?
The current implementation of `make_petastorm_dataset` for TensorFlow doesn't support multiple iterations: https://github.com/uber/petastorm/blob/7f37e8dde6ff1b13f055d22a6289e2de8bb5d473/petastorm/tf_utils.py#L370

It is recommended to set the reader's `num_epochs > 1` to support multiple iterations. However, this can cause samples to be duplicated and dropped within a single epoch when used together with `tf.data.Dataset.shuffle`.
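For context, this is roughly the pattern the current API steers users toward (a minimal sketch; the dataset URL is hypothetical):

```python
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

# The reader is asked for all 10 epochs up front, because the
# tf.data.Dataset returned by make_petastorm_dataset can only be
# iterated once per reader.
with make_reader('file:///tmp/my_dataset', num_epochs=10) as reader:
    dataset = make_petastorm_dataset(reader)
```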
Let's say:

- My training rows are `[0, 1, 2, 3, 4, 5, 6, 7]`.
- I plan to train for 10 epochs, so I set `reader.num_epochs = 10`.
- I want to shuffle my samples during training as well, so I add `shuffle` to the dataset as below:

  ```python
  dataset = make_petastorm_dataset(reader)
  dataset = dataset.shuffle(8)
  ```
- During training, due to the implementation of `shuffle` in TF (from the `tf.data.Dataset.shuffle` API docs):

  > This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.
  >
  > For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the 1,000 element buffer.
- A training epoch will then look like:
  - The buffer fills with `[0, 1, 2, 3, 4, 5, 6, 7]`; `shuffle` randomly picks 7 and refills the freed slot with 0 (the next element available from the reader, i.e. the start of the second epoch): `[0, 1, 2, 3, 4, 5, 6, 0]`.
  - Now it is quite likely that 0 will be selected twice within the next few iterations, even before some other values are selected once.
  - By the end of the epoch, 0 has very likely been duplicated while some other values have been dropped (a runnable repro follows this list).
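The effect can be reproduced with plain `tf.data`, using a range dataset as a stand-in for the petastorm reader stream (a minimal sketch, not petastorm-specific):

```python
import collections

import tensorflow as tf

# Stand-in for the reader stream: 8 rows repeated for 10 epochs,
# mirroring reader.num_epochs = 10 from the example above.
dataset = tf.data.Dataset.range(8).repeat(10).shuffle(8)

# Inspect one "epoch" (the first 8 elements). Rows from the second
# epoch leak into the shuffle buffer, so some rows typically appear
# twice while others are missing.
counts = collections.Counter(dataset.take(8).as_numpy_iterator())
print(counts)
```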
This example is reproducible in our real applications. The reason is that `tf.data.Dataset.shuffle` always refills the shuffle buffer with the next available element from the reader, and since we configured the reader to run for multiple epochs, those elements come from the following epoch.
A work-around to fix this:

- Set the reader's `num_epochs` to 1.
- Enable multiple iterations in `make_petastorm_dataset`.
- Use `shuffle` and `repeat` from `tf.data` as:

  ```python
  dataset = make_petastorm_dataset(reader)
  dataset = dataset.shuffle(8)
  dataset = dataset.repeat(num_epochs)
  ```
- In this case, `repeat` isolates the epochs from one another, so we no longer see samples dropped or duplicated within a single epoch (verified in our applications); see the sketch below.
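Put together, the workaround would look roughly like this (a sketch, assuming `make_petastorm_dataset` is changed to allow re-iterating a `num_epochs=1` reader; the dataset URL is hypothetical):

```python
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

num_epochs = 10

# One pass over the data per iteration of the tf.data.Dataset.
with make_reader('file:///tmp/my_dataset', num_epochs=1) as reader:
    dataset = make_petastorm_dataset(reader)
    dataset = dataset.shuffle(8)          # buffer covers one full epoch
    dataset = dataset.repeat(num_epochs)  # repeat *after* shuffle keeps epochs isolated
    for row in dataset:
        ...  # training step
```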
@selitvin What do you think?