`drop_last_batch` does not drop the last batch when using IterableDataset + interleave_datasets + multiple DataLoader workers
Describe the bug
See the script below.
drop_last_batch=True is passed to map() for each dataset.
Each dataset yields 24 samples, so with batch_size=10 the last, incomplete batch (ids 21-24) is expected to be dropped.
The code behaves as expected when num_workers=0 or 1.
With num_workers>1, 'a-11', 'b-11', 'a-12', 'b-12' are missing, and 'a-21', 'b-21', 'a-22', 'b-22' are sampled instead.
Steps to reproduce the bug
from datasets import Dataset
from datasets import interleave_datasets
from torch.utils.data import DataLoader


def convert_to_str(batch, dataset_name):
    batch['a'] = [f"{dataset_name}-{e}" for e in batch['a']]
    return batch


def gen1():
    for ii in range(1, 25):
        yield {"a": ii}


def gen2():
    for ii in range(1, 25):
        yield {"a": ii}


# https://github.com/huggingface/datasets/issues/6565
if __name__ == '__main__':
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=2)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=2)

    dataset1 = dataset1.map(lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=10, drop_last_batch=True)
    dataset2 = dataset2.map(lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=10, drop_last_batch=True)

    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")

    print("num_workers=0")
    loader = DataLoader(interleaved, batch_size=5, num_workers=0)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print("num_workers=1")
    loader = DataLoader(interleaved, batch_size=5, num_workers=1)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print("num_workers=2")
    loader = DataLoader(interleaved, batch_size=5, num_workers=2)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print("num_workers=3")
    loader = DataLoader(interleaved, batch_size=5, num_workers=3)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1
The output is:
num_workers=0
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
4 ['a-11', 'b-11', 'a-12', 'b-12', 'a-13']
5 ['b-13', 'a-14', 'b-14', 'a-15', 'b-15']
6 ['a-16', 'b-16', 'a-17', 'b-17', 'a-18']
7 ['b-18', 'a-19', 'b-19', 'a-20', 'b-20']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=1
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
4 ['a-11', 'b-11', 'a-12', 'b-12', 'a-13']
5 ['b-13', 'a-14', 'b-14', 'a-15', 'b-15']
6 ['a-16', 'b-16', 'a-17', 'b-17', 'a-18']
7 ['b-18', 'a-19', 'b-19', 'a-20', 'b-20']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=2
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['a-13', 'b-13', 'a-14', 'b-14', 'a-15']
2 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
3 ['b-15', 'a-16', 'b-16', 'a-17', 'b-17']
4 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
5 ['a-18', 'b-18', 'a-19', 'b-19', 'a-20']
6 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
7 ['b-20', 'a-21', 'b-21', 'a-22', 'b-22']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=3
Too many dataloader workers: 3 (max is dataset.num_shards=2). Stopping 1 dataloader workers.
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['a-13', 'b-13', 'a-14', 'b-14', 'a-15']
2 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
3 ['b-15', 'a-16', 'b-16', 'a-17', 'b-17']
4 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
5 ['a-18', 'b-18', 'a-19', 'b-19', 'a-20']
6 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
7 ['b-20', 'a-21', 'b-21', 'a-22', 'b-22']
Expected behavior
'a-21', 'b-21', 'a-22', 'b-22' should be dropped, and 'a-11', 'b-11', 'a-12', 'b-12' should be kept, as in the num_workers=0 and num_workers=1 runs.
Environment info
- datasets version: 3.3.2
- Platform: Linux-5.15.0-1056-aws-x86_64-with-glibc2.31
- Python version: 3.10.16
- huggingface_hub version: 0.28.0
- PyArrow version: 19.0.0
- Pandas version: 2.2.3
- fsspec version: 2024.6.1
Hi @memray, I’d like to help fix the issue with drop_last_batch not working when num_workers > 1. I’ll investigate and propose a solution. Thanks!
Thank you very much for offering to help! I also noticed a problem related to a previous issue and left a comment here (the code checks validity before certain columns are removed). Can you take a look as well?
I looked into this, and the problem seems to be the order in which sharding and batching are applied, i.e. how drop_last_batch is handled per worker (see the potential solutions below if that is unclear). Since we have 2 workers and 2 shards, the data is split into 1-12 on worker 1 and 13-24 on worker 2. Each of those workers then iterates in batches of 10 and drops its own incomplete last batch, so worker 1 drops {11, 12} and worker 2 drops {23, 24} (a small simulation of this split is sketched after the list below). There are multiple ways to circumvent that:
- distribute the batches to the workers in turn and tell each worker whether it should drop its batches, so that only the worker holding the final batch drops anything
- distribute the data as we do now, but tell each worker individually how many samples to drop (this would require each worker to know how many samples it holds and how many samples there are in total); this could work and is closer to the current behavior, but it is probably considerably more complex
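To make the split concrete, here is a small standalone simulation (plain Python, not datasets internals; the contiguous shard split is an assumption that matches the repro output) of what each worker keeps and drops in the OP's setup:

# Plain-Python simulation of the behavior described above; this is NOT
# datasets-library code, and the contiguous shard split is an assumption
# that matches the repro output, not necessarily the exact internals.

def split_into_shards(samples, num_shards):
    # e.g. 24 samples / 2 shards -> [1..12], [13..24]
    shard_size = len(samples) // num_shards
    return [samples[i * shard_size:(i + 1) * shard_size] for i in range(num_shards)]

def drop_last_incomplete_batch(samples, batch_size):
    # what map(..., drop_last_batch=True) effectively does on each worker's stream
    n_full = (len(samples) // batch_size) * batch_size
    return samples[:n_full]

samples = list(range(1, 25))  # ids 1-24, as in the repro script
for num_workers in (1, 2):
    per_worker = split_into_shards(samples, num_workers)
    kept = [drop_last_incomplete_batch(w, batch_size=10) for w in per_worker]
    dropped = [w[len(k):] for w, k in zip(per_worker, kept)]
    print(f"{num_workers} worker(s) -> dropped per worker: {dropped}")

# Output:
# 1 worker(s) -> dropped per worker: [[21, 22, 23, 24]]
# 2 worker(s) -> dropped per worker: [[11, 12], [23, 24]]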
Note that the OP's example is just the tip of the iceberg: in fact, all data can be dropped if shards, workers, and batch sizes are chosen accordingly:
from datasets import Dataset, interleave_datasets
from torch.utils.data import DataLoader


def convert_to_str(batch, dataset_name):
    batch["a"] = [f"{dataset_name}-{e}" for e in batch["a"]]
    return batch


number = 16  # 15 samples (1-15) per dataset


def gen1():
    for ii in range(1, number):
        yield {"a": ii}


def gen2():
    for ii in range(1, number):
        yield {"a": ii}


if __name__ == "__main__":
    print("=" * 40)
    print("num_workers=1")
    print("=" * 40)
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=3)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=3)
    dataset1 = dataset1.map(
        lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=9, drop_last_batch=True
    )
    dataset2 = dataset2.map(
        lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=9, drop_last_batch=True
    )
    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")
    loader = DataLoader(interleaved, batch_size=5, num_workers=1)
    i = 0
    for b in loader:
        print(i, b["a"])
        i += 1

    print()
    print("=" * 40)
    print("num_workers=3")
    print("=" * 40)
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=3)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=3)
    dataset1 = dataset1.map(
        lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=9, drop_last_batch=True
    )
    dataset2 = dataset2.map(
        lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=9, drop_last_batch=True
    )
    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")
    loader = DataLoader(interleaved, batch_size=5, num_workers=3)
    i = 0
    for b in loader:
        print(i, b["a"])
        i += 1
    if i == 0:
        print("Everything got dropped!")
========================================
num_workers=1
========================================
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9']
========================================
num_workers=3
========================================
Everything got dropped!
EDIT: I looked into this a bit more and have to revise my stance on the solutions. I think solution one is not feasible, since the data is divided into shards before the batch_size is known. That leaves only option 2 on the table as far as I can see right now.
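For illustration only, here is a rough sketch of the per-worker bookkeeping option 2 would need, under the assumption that each worker knows the total number of samples and which slice of the global stream it holds (which is exactly the information that is hard to get for an IterableDataset); this is my own illustration, not a proposed patch:

# Hedged sketch of option 2: compute, per worker, how many of its samples fall
# into the globally dropped tail (the last total % batch_size samples), so that
# globally only the final incomplete batch is dropped.

def samples_to_drop_per_worker(worker_lengths, batch_size):
    total = sum(worker_lengths)
    remaining_tail = total % batch_size  # size of the globally incomplete last batch
    drops = []
    # walk the workers backwards from the end of the global stream
    for length in reversed(worker_lengths):
        drop = min(length, remaining_tail)
        drops.append(drop)
        remaining_tail -= drop
    return list(reversed(drops))

# 24 samples, 2 workers holding 12 each, batch_size=10:
print(samples_to_drop_per_worker([12, 12], batch_size=10))  # [0, 4] -> only worker 2 drops (ids 21-24)

# 15 samples, 3 workers holding 5 each, batch_size=9:
print(samples_to_drop_per_worker([5, 5, 5], batch_size=9))  # [0, 1, 5] -> ids 10-15 dropped, 1-9 kept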