
`drop_last_batch` does not drop the last batch using IterableDataset + interleave_datasets + multi_worker

Open memray opened this issue 10 months ago • 3 comments

Describe the bug

See the script below: drop_last_batch=True is passed to map() for each dataset, so the last (incomplete) batch of each dataset, ids 21-24, is expected to be dropped. The code behaves as expected when num_workers=0 or 1. When using num_workers>1, 'a-11', 'b-11', 'a-12', 'b-12' are gone and ids 21 and 22 are sampled instead.

Steps to reproduce the bug

from datasets import Dataset
from datasets import interleave_datasets
from torch.utils.data import DataLoader

def convert_to_str(batch, dataset_name):
    batch['a'] = [f"{dataset_name}-{e}" for e in batch['a']]
    return batch

def gen1():
    for ii in range(1, 25):
        yield {"a": ii}

def gen2():
    for ii in range(1, 25):
        yield {"a": ii}

# https://github.com/huggingface/datasets/issues/6565
if __name__ == '__main__':
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=2)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=2)
    dataset1 = dataset1.map(lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=10, drop_last_batch=True)
    dataset2 = dataset2.map(lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=10, drop_last_batch=True)

    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")

    print(f"num_workers=0")
    loader = DataLoader(interleaved, batch_size=5, num_workers=0)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print(f"num_workers=1")
    loader = DataLoader(interleaved, batch_size=5, num_workers=1)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print(f"num_workers=2")
    loader = DataLoader(interleaved, batch_size=5, num_workers=2)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

    print('=-' * 20)
    print(f"num_workers=3")
    loader = DataLoader(interleaved, batch_size=5, num_workers=3)
    i = 0
    for b in loader:
        print(i, b['a'])
        i += 1

output is:

num_workers=0
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
4 ['a-11', 'b-11', 'a-12', 'b-12', 'a-13']
5 ['b-13', 'a-14', 'b-14', 'a-15', 'b-15']
6 ['a-16', 'b-16', 'a-17', 'b-17', 'a-18']
7 ['b-18', 'a-19', 'b-19', 'a-20', 'b-20']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=1
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
4 ['a-11', 'b-11', 'a-12', 'b-12', 'a-13']
5 ['b-13', 'a-14', 'b-14', 'a-15', 'b-15']
6 ['a-16', 'b-16', 'a-17', 'b-17', 'a-18']
7 ['b-18', 'a-19', 'b-19', 'a-20', 'b-20']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=2
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['a-13', 'b-13', 'a-14', 'b-14', 'a-15']
2 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
3 ['b-15', 'a-16', 'b-16', 'a-17', 'b-17']
4 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
5 ['a-18', 'b-18', 'a-19', 'b-19', 'a-20']
6 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
7 ['b-20', 'a-21', 'b-21', 'a-22', 'b-22']
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
num_workers=3
Too many dataloader workers: 3 (max is dataset.num_shards=2). Stopping 1 dataloader workers.
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['a-13', 'b-13', 'a-14', 'b-14', 'a-15']
2 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
3 ['b-15', 'a-16', 'b-16', 'a-17', 'b-17']
4 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
5 ['a-18', 'b-18', 'a-19', 'b-19', 'a-20']
6 ['b-8', 'a-9', 'b-9', 'a-10', 'b-10']
7 ['b-20', 'a-21', 'b-21', 'a-22', 'b-22']


Expected behavior

'a-21', 'b-21', 'a-22', 'b-22' should be dropped

Environment info

  • datasets version: 3.3.2
  • Platform: Linux-5.15.0-1056-aws-x86_64-with-glibc2.31
  • Python version: 3.10.16
  • huggingface_hub version: 0.28.0
  • PyArrow version: 19.0.0
  • Pandas version: 2.2.3
  • fsspec version: 2024.6.1

memray · Mar 08 '25 10:03

Hi @memray, I’d like to help fix the issue with drop_last_batch not working when num_workers > 1. I’ll investigate and propose a solution. Thanks!

Rawdyrathaur · Mar 09 '25 05:03

Thank you very much for offering to help! I also noticed a problem related to a previous issue and left a comment here (the code checks validity before certain columns are removed). Can you take a look at that as well?

memray · Mar 09 '25 21:03

I looked into this, and the problem seems to be the order of sharding and batching, i.e. how drop_last_batch is applied per worker (see the potential solutions below if unclear). Since we have 2 workers and 2 shards, the data is split into 1-12 on worker 1 and 13-24 on worker 2. Each of those workers then iterates in batches of 10 and drops its own incomplete last batch, so worker 1 drops {11, 12} and worker 2 drops {23, 24}; the small simulation after the list below illustrates this. There are multiple ways to circumvent that:

  • distribute batches to workers in turns and tell each worker individually whether it should drop its last batch, so that only the last worker drops anything
  • distribute the data as it is done now but tell each worker individually how many samples to drop (this would require each worker to know how many samples it holds and how many samples there are in total). This could work but is probably more complex, although closer to the current behaviour.
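
To make the failure mode concrete, here is a minimal, self-contained simulation of that per-worker behaviour in plain Python (no datasets/torch involved; the helper name is mine, not part of the library): 24 samples are split over 2 shards/workers, and each worker batches its own shard with batch_size=10 and drops its incomplete last batch, mirroring OP's numbers.

def drop_last_batch_per_worker(shard, batch_size):
    # Keep only full batches from this worker's shard (drop_last_batch=True applied per worker).
    kept = []
    for start in range(0, len(shard), batch_size):
        batch = shard[start:start + batch_size]
        if len(batch) == batch_size:
            kept.extend(batch)
    return kept

samples = list(range(1, 25))           # ids 1-24
shards = [samples[:12], samples[12:]]  # worker 1: ids 1-12, worker 2: ids 13-24

for worker_id, shard in enumerate(shards, start=1):
    kept = drop_last_batch_per_worker(shard, batch_size=10)
    dropped = sorted(set(shard) - set(kept))
    print(f"worker {worker_id}: keeps {kept}, drops {dropped}")
# prints: worker 1 keeps 1-10 and drops [11, 12]; worker 2 keeps 13-22 and drops [23, 24]

This is exactly why 21 and 22 survive in OP's num_workers=2 run while 11 and 12 disappear.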

Note that OP's example is just the tip of the iceberg: all data can be dropped if we choose shards, workers and batch sizes accordingly. In the example below, 15 samples per dataset are spread over 3 shards, so with 3 workers each worker holds only 5 samples per dataset, which is fewer than batch_size=9, and every per-worker batch gets dropped:

from datasets import Dataset, interleave_datasets
from torch.utils.data import DataLoader


def convert_to_str(batch, dataset_name):
    batch["a"] = [f"{dataset_name}-{e}" for e in batch["a"]]
    return batch


number = 16  # 15 samples (1-15)


def gen1():
    for ii in range(1, number):
        yield {"a": ii}


def gen2():
    for ii in range(1, number):
        yield {"a": ii}

if __name__ == "__main__":
    print("=" * 40)
    print("num_workers=1")
    print("=" * 40)
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=3)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=3)
    dataset1 = dataset1.map(
        lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=9, drop_last_batch=True
    )
    dataset2 = dataset2.map(
        lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=9, drop_last_batch=True
    )

    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")

    loader = DataLoader(interleaved, batch_size=5, num_workers=1)
    i = 0
    for b in loader:
        print(i, b["a"])
        i += 1

    print()
    print("=" * 40)
    print("num_workers=3")
    print("=" * 40)
    dataset1 = Dataset.from_generator(gen1).to_iterable_dataset(num_shards=3)
    dataset2 = Dataset.from_generator(gen2).to_iterable_dataset(num_shards=3)
    dataset1 = dataset1.map(
        lambda x: convert_to_str(x, dataset_name="a"), batched=True, batch_size=9, drop_last_batch=True
    )
    dataset2 = dataset2.map(
        lambda x: convert_to_str(x, dataset_name="b"), batched=True, batch_size=9, drop_last_batch=True
    )

    interleaved = interleave_datasets([dataset1, dataset2], stopping_strategy="all_exhausted")

    loader = DataLoader(interleaved, batch_size=5, num_workers=3)
    i = 0
    for b in loader:
        print(i, b["a"])
        i += 1

    if i == 0:
        print("Everything got dropped!")

output is:

========================================
num_workers=1
========================================
0 ['a-1', 'b-1', 'a-2', 'b-2', 'a-3']
1 ['b-3', 'a-4', 'b-4', 'a-5', 'b-5']
2 ['a-6', 'b-6', 'a-7', 'b-7', 'a-8']
3 ['b-8', 'a-9', 'b-9']

========================================
num_workers=3
========================================
Everything got dropped!

EDIT: I looked into this a bit more and I am revising my stance on the solutions. I think solution one is not feasible since we divide into shards before we know the batch_size. That leaves only option 2 on the table as far as I can see right now.
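
For what it's worth, here is a rough sketch of the bookkeeping option 2 would need (purely illustrative; the helper below is not part of datasets, and "drop from the last workers" is just one possible convention): given each worker's sample count, compute how many trailing samples each worker must drop so that globally only the final incomplete batch disappears.

def per_worker_drop_counts(samples_per_worker, batch_size):
    # Size of the single incomplete batch we want to drop globally.
    to_drop = sum(samples_per_worker) % batch_size
    drops = [0] * len(samples_per_worker)
    # Remove samples from the tail of the stream, i.e. starting with the last worker.
    for w in reversed(range(len(samples_per_worker))):
        drops[w] = min(to_drop, samples_per_worker[w])
        to_drop -= drops[w]
        if to_drop == 0:
            break
    return drops

print(per_worker_drop_counts([12, 12], 10))  # [0, 4]

With OP's setup (24 samples per dataset split 12/12, map batch_size=10) only the second worker would drop its last 4 samples, i.e. ids 21-24, which matches the single-process behaviour.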

CloseChoice · Oct 09 '25 05:10