Using ShardingFilterIterDataPipe with MultiProcessingReadingService may cause unnecessary batch drops.
🐛 Describe the bug
When using ShardingFilterIterDataPipe, the data in the datapipe is sharded evenly across num_of_instances workers, element by element. However, if batch() is called later on the datapipe, this element-wise round-robin distribution can leave each worker with partial batches, causing data to be discarded (when drop_last=True) that would not need to be discarded otherwise.
This might not be considered a bug, but it is unexpected. In addition, the current ShardingFilterIterDataPipe produces different batches for different numbers of workers, which is also surprising.
```python
import torchdata
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

dp = torchdata.datapipes.iter.IterableWrapper(range(15)).sharding_filter().batch(5)

loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
for i in loader:
    print(i)
loader.shutdown()

print("++++++++++++++++++++++++++++++++++++++++++++++")

loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=1))
for i in loader:
    print(i)
loader.shutdown()
```
This gives the following result:
```
[0, 2, 4, 6, 8]
[1, 3, 5, 7, 9]
[10, 12, 14]  # these two incomplete batches would be dropped if drop_last were set to True
[11, 13]
++++++++++++++++++++++++++++++++++++++++++++++
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
```
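To see where the partial batches come from, here is a minimal sketch that reproduces the per-worker view without DataLoader2, assuming ShardingFilterIterDataPipe.apply_sharding() can be called directly with the number of instances and the instance id:

```python
import torchdata

# Simulate what each of the two workers sees: elements are sharded
# round-robin *before* batching, so each worker ends up with partial batches.
for instance_id in range(2):
    dp = torchdata.datapipes.iter.IterableWrapper(range(15)).sharding_filter()
    dp.apply_sharding(2, instance_id)  # 2 instances, this worker's id
    print(f"worker {instance_id}:", list(dp.batch(5)))

# worker 0: [[0, 2, 4, 6, 8], [10, 12, 14]]
# worker 1: [[1, 3, 5, 7, 9], [11, 13]]
```

Because the 15 elements are split element-wise before batching, each worker holds 7 or 8 elements, and the trailing 3- and 2-element batches would be dropped with drop_last=True even though the full dataset divides evenly into batches of 5.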
One solution to this is to use a sharding filter that is aware of the batch size used downstream in the datapipe, for example something like the following:
```python
import torchdata


class BatchShardingFilterIterDataPipe(torchdata.datapipes.iter.ShardingFilter):
    def __init__(self, source_datapipe, sharding_group_filter=None):
        super().__init__(source_datapipe, sharding_group_filter)
        self.batch_size = 1

    def set_batch_size(self, batch_size, drop_last):
        # drop_last is accepted for API symmetry but not used in this sketch
        self.batch_size = batch_size

    def __iter__(self):
        # Shard whole batches across workers (round-robin over batches rather
        # than over individual elements), so every worker yields only complete
        # batches, except possibly the final one.
        for i, batch_items in enumerate(
            self.source_datapipe.batch(batch_size=self.batch_size, drop_last=False)
        ):
            if i % self.num_of_instances == self.instance_id:
                yield from batch_items
```
set_batch_size() would need to be called once the downstream batch size is known.
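As a rough sketch of the intended behavior (again applying the sharding manually rather than through DataLoader2, with set_batch_size() matching the downstream batch(5)):

```python
# Each worker now receives whole batches of 5, so nothing would be dropped
# with drop_last=True (minimal illustration, not wired into DataLoader2).
for instance_id in range(2):
    dp = BatchShardingFilterIterDataPipe(
        torchdata.datapipes.iter.IterableWrapper(range(15))
    )
    dp.set_batch_size(5, drop_last=True)  # must match the downstream batch(5)
    dp.apply_sharding(2, instance_id)
    print(f"worker {instance_id}:", list(dp.batch(5)))

# worker 0: [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]
# worker 1: [[5, 6, 7, 8, 9]]
```

Each worker only sees whole batches, and the resulting batches are the same regardless of the number of workers.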
I wonder what the torchdata team thinks of the current sharding filter. Is its behavior expected?
Versions
```
torch==2.0.0
torchaudio==2.0.0
torchdata==0.6.0
```