
Using ShardingFilterIterDataPipe with MPRS may cause unnecessary batch drops.


🐛 Describe the bug

When using ShardingFilterIterDataPipe, the data in the datapipe is sharded evenly across num_of_instances workers, element by element in a round-robin fashion. However, if we call batch() later on the datapipe, this element-level split can cause workers to discard data that would not need to be discarded otherwise.

This might not be considered a bug, but it is unexpected. Besides, the current ShardingFilterIterDataPipe produces different batches of data for different numbers of workers, which is also unexpected.

import torchdata.datapipes.iter
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# Shard 15 items across workers, then batch with a batch size of 5.
dp = torchdata.datapipes.iter.IterableWrapper(range(15)).sharding_filter().batch(5)

# Two workers: each worker only sees every other element before batching.
loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
for i in loader:
    print(i)
loader.shutdown()
print("++++++++++++++++++++++++++++++++++++++++++++++")
# One worker: batches are formed from the full, contiguous range.
loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=1))
for i in loader:
    print(i)
loader.shutdown()

This gives the following result:

[0, 2, 4, 6, 8]
[1, 3, 5, 7, 9]
[10, 12, 14]  # These two batches would be dropped if we set drop_last to True
[11, 13]
++++++++++++++++++++++++++++++++++++++++++++++
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
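
For reference, here is a plain-Python sketch of the element-level round-robin split that sharding_filter applies (round_robin_shard and the local batch helper below are illustrative names, not torchdata APIs); it shows why each worker ends up with its own partial final batch:

# Plain-Python illustration of element-level round-robin sharding
# followed by per-worker batching.
def round_robin_shard(data, num_of_instances, instance_id):
    return [x for i, x in enumerate(data) if i % num_of_instances == instance_id]

def batch(data, batch_size):
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

for worker_id in range(2):
    shard = round_robin_shard(range(15), num_of_instances=2, instance_id=worker_id)
    print(worker_id, batch(shard, batch_size=5))
# 0 [[0, 2, 4, 6, 8], [10, 12, 14]]  <- partial batch on each worker
# 1 [[1, 3, 5, 7, 9], [11, 13]]      <- partial batch on each worker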

One solution to this is to use a sharding filter that is aware of the batch size of the datapipe. Maybe something like the following:

class BatchShardingFilterIterDataPipe(torchdata.datapipes.iter.ShardingFilter):
    def __init__(self, source_datapipe, sharding_group_filter=None):
        super().__init__(source_datapipe, sharding_group_filter)
        self.batch_size = 1

    def set_batch_size(self, batch_size, drop_last):
        # Must be called once the downstream batch size is known.
        self.batch_size = batch_size

    def __iter__(self):
        # Distribute whole batches round-robin across workers instead of
        # individual elements, so that a downstream batch() with the same
        # batch size always sees contiguous, full batches.
        for i, batch_items in enumerate(
            self.source_datapipe.batch(batch_size=self.batch_size, drop_last=False)
        ):
            if i % self.num_of_instances == self.instance_id:
                yield from batch_items

set_batch_size() needs to be called once the batch size is determined.
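
For illustration, assuming the reading service applies sharding to this subclass the same way it does to the built-in sharding_filter, a hypothetical usage might look like the following (the expected output is reasoned from the __iter__ logic above, not a verified run):

# Hypothetical usage of the proposed BatchShardingFilterIterDataPipe.
dp = torchdata.datapipes.iter.IterableWrapper(range(15))
dp = BatchShardingFilterIterDataPipe(dp)
dp.set_batch_size(5, drop_last=False)  # call once the batch size is known
dp = dp.batch(5)

loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
for i in loader:
    print(i)
loader.shutdown()
# Expected (batch order across workers may vary):
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [10, 11, 12, 13, 14]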

I wonder what the torchdata team thinks of the current sharding filter. Is its behavior expected?

Versions

torch 2.0.0
torchaudio 2.0.0
torchdata 0.6.0

yuxinyuan · Jun 07, 2023