
Split dataset by node: index error when sharding iterable dataset

Open · sanchit-gandhi opened this issue 1 year ago • 3 comments

Describe the bug

Context: we're splitting an iterable dataset by node and then passing it to a PyTorch DataLoader with multiple workers.

When we iterate over it for 5 steps, we don't get an error.

When we instead iterate over it for 8 steps, we get an IndexError when fetching the data if we have too many workers.

Steps to reproduce the bug

Here, we have 2 JAX processes (jax.process_count() = 2) across which we split the dataset. The dataset loading script can be found here: https://huggingface.co/datasets/distil-whisper/librispeech_asr/blob/c6a1e805cbfeed5057400ac5937327d7e30281b8/librispeech_asr.py#L310
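For context, the shard count that the streaming dataset reports (and that num_workers is set to below) can be checked directly; going by the warning later in the logs, it is 14 before splitting, i.e. 7 per node:

from datasets import load_dataset

# check how many shards the streaming dataset reports (this is what drives the worker assignment)
ds = load_dataset("distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True)
print(ds.n_shards)  # 14 for this config/split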

Code to reproduce
from datasets import load_dataset
import jax
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from tqdm import tqdm

# load an example dataset (https://huggingface.co/datasets/distil-whisper/librispeech_asr)
dataset = load_dataset("distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True)
# just keep the text column -> no need to define a collator
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})

# define some constants
batch_size = 256
num_examples = 5  # works for 5 examples, doesn't for 8
num_workers = dataset_text.n_shards

# try with multiple workers
dataloader = DataLoader(dataset_text, batch_size=batch_size, num_workers=num_workers, drop_last=True)

for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Multiple workers"):
    if i == num_examples:
        break

# try splitting by node (we can't do this with `dataset_text` since `split_dataset_by_node` expects the Audio column for an ASR dataset)
dataset = split_dataset_by_node(dataset, rank=jax.process_index(), world_size=jax.process_count())
# remove the text column again
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})
dataloader = DataLoader(dataset_text, batch_size=16, num_workers=num_workers // 2, drop_last=True)

for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Split by node"):
    if i == num_examples:
        break

# too many workers
dataloader = DataLoader(dataset_text, batch_size=256, num_workers=num_workers, drop_last=True)
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
    if i == num_examples:
        break
With 5 examples:
Multiple workers: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00,  3.33s/it]
Assigning 7 shards (or data sources) of the dataset to each node.                                                       
Split by node: 100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:13<00:00,  2.76s/it]
Assigning 7 shards (or data sources) of the dataset to each node.                                                       
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:15<00:00,  3.03s/it]
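(Side note: my reading of the warning above, sketched in plain Python rather than the library's actual assignment code, is that each dataloader worker is handed a subset of the shards, so with 7 shards only 7 workers can ever receive data:)

# illustrative only -- not the exact logic inside datasets
n_shards, num_workers = 7, 14
assignment = {
    worker_id: [shard for shard in range(n_shards) if shard % num_workers == worker_id]
    for worker_id in range(num_workers)
}
# workers 0-6 each get one shard; workers 7-13 get an empty list, hence "Stopping 7 dataloader workers"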
With 8 examples:
Multiple workers: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.71s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Split by node: 100%|██████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00,  1.38s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers:  88%|██████████████████████████████████████████████████████████▋        | 7/8 [00:13<00:01,  1.89s/it]
Traceback (most recent call last):
  File "distil-whisper/test_librispeech.py", line 36, in <module>
    for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1325, in _next_data
    return self._process_data(data)
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 7.
Original Traceback (most recent call last):
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 986, in __iter__
    yield from self._iter_pytorch(ex_iterable)
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 920, in _iter_pytorch
    for key, example in ex_iterable.shard_data_sources(worker_info.id, worker_info.num_workers):
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 540, in shard_data_sources
    self.ex_iterable.shard_data_sources(worker_id, num_workers),
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 796, in shard_data_sources
    self.ex_iterable.shard_data_sources(worker_id, num_workers),
  File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 126, in shard_data_sources
    requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
  File "/home/sanchitgandhi/datasets/src/datasets/utils/sharding.py", line 76, in _merge_gen_kwargs
    for key in gen_kwargs_list[0]
IndexError: list index out of range
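The last two frames suggest that a worker ends up with an empty list of gen_kwargs to merge. A stand-in sketch of that failure mode (merge_gen_kwargs here is hypothetical, just mimicking the pattern visible in the traceback, not the library's actual function):

def merge_gen_kwargs(gen_kwargs_list):
    # iterates over the keys of the first element, like the line shown in the traceback
    return {key: [kwargs[key] for kwargs in gen_kwargs_list] for key in gen_kwargs_list[0]}

merge_gen_kwargs([{"files": ["shard-0"]}, {"files": ["shard-1"]}])  # fine
merge_gen_kwargs([])  # IndexError: list index out of range, as above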

Expected behavior

Should pass for both 5 and 8 examples.
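For anyone without a multi-host JAX setup, here is a smaller sketch along the same lines (a hypothetical toy generator dataset with 14 shards, with the rank passed explicitly) that I'd expect to go through the same sharding path:

from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

def gen(shards):
    # each entry in `shards` plays the role of one data file
    for shard in shards:
        for i in range(100):
            yield {"text": f"{shard}-{i}"}

shards = [f"shard_{i}" for i in range(14)]  # 14 shards, like the real dataset
ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": shards})
ds = split_dataset_by_node(ds, rank=0, world_size=2)  # 7 shards left on this node
dataloader = DataLoader(ds, batch_size=16, num_workers=14, drop_last=True)  # more workers than shards

for i, batch in enumerate(dataloader):
    if i == 8:
        break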

Environment info

  • datasets version: 2.12.1.dev0
  • Platform: Linux-5.13.0-1023-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.14.1
  • PyArrow version: 12.0.0
  • Pandas version: 2.0.1

sanchit-gandhi avatar May 22 '23 10:05 sanchit-gandhi

cc @lhoestq in case you have any ideas here! Might need a multi-host set-up to debug (I can give you access to a JAX one if you need it).
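(One thing that might help narrow it down without a multi-host setup: loop over the ranks in a single process and check how many shards each one reports, e.g.)

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

dataset = load_dataset("distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True)
for rank in range(2):
    ds_rank = split_dataset_by_node(dataset, rank=rank, world_size=2)
    print(rank, ds_rank.n_shards)  # expect 7 shards on each rank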

sanchit-gandhi avatar May 22 '23 10:05 sanchit-gandhi

I am also facing the same problem. Could you let me know if you found a solution for this?

Munikumar09 avatar Jan 28 '24 10:01 Munikumar09

I couldn't reproduce this with the latest version of datasets (2.16.1). Can you update datasets and try again?
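Something like this should confirm which version is picked up:

# after `pip install -U datasets`
import datasets
print(datasets.__version__)  # should be >= 2.16.1 before re-running the reproduction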

lhoestq avatar Jan 29 '24 14:01 lhoestq