Split dataset by node: index error when sharding iterable dataset
Describe the bug
Context: we're splitting an iterable dataset by node and then passing it to a torch DataLoader with multiple workers.
When we iterate over it for 5 steps, we don't get an error.
When we instead iterate over it for 8 steps, we get an IndexError when fetching the data if we have too many workers.
Steps to reproduce the bug
Here, we have 2 JAX processes (jax.process_count() = 2) over which we split the dataset. The dataset loading script can be found here: https://huggingface.co/datasets/distil-whisper/librispeech_asr/blob/c6a1e805cbfeed5057400ac5937327d7e30281b8/librispeech_asr.py#L310
Code to reproduce
from datasets import load_dataset
import jax
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from tqdm import tqdm
# load an example dataset (https://huggingface.co/datasets/distil-whisper/librispeech_asr)
dataset = load_dataset("distil-whisper/librispeech_asr", "all", split="train.clean.100", streaming=True)
# just keep the text column -> no need to define a collator
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})
# define some constants
batch_size = 256
num_examples = 5  # number of steps to iterate: works for 5, raises IndexError for 8
num_workers = dataset_text.n_shards
# try with multiple workers
dataloader = DataLoader(dataset_text, batch_size=batch_size, num_workers=num_workers, drop_last=True)
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Multiple workers"):
    if i == num_examples:
        break
# try splitting by node (we can't do this with `dataset_text` since `split_dataset_by_node` expects the Audio column for an ASR dataset)
dataset = split_dataset_by_node(dataset, rank=jax.process_index(), world_size=jax.process_count())
# keep only the text column again
dataset_text = dataset.remove_columns(set(dataset.features.keys()) - {"text"})
dataloader = DataLoader(dataset_text, batch_size=16, num_workers=num_workers // 2, drop_last=True)
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Split by node"):
    if i == num_examples:
        break
# too many workers: num_workers is the shard count before splitting, so it exceeds the shards assigned to this node
dataloader = DataLoader(dataset_text, batch_size=256, num_workers=num_workers, drop_last=True)
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
    if i == num_examples:
        break
With 5 examples:
Multiple workers: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:16<00:00, 3.33s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Split by node: 100%|██████████████████████████████████████████████████████████████████████| 5/5 [00:13<00:00, 2.76s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:15<00:00, 3.03s/it]
With 8 examples:
Multiple workers: 100%|███████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00, 1.71s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Split by node: 100%|██████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00, 1.38s/it]
Assigning 7 shards (or data sources) of the dataset to each node.
Too many dataloader workers: 14 (max is dataset.n_shards=7). Stopping 7 dataloader workers.
To parallelize data loading, we give each process some shards (or data sources) to process. Therefore it's unnecessary to have a number of workers greater than dataset.n_shards=7. To enable more parallelism, please split the dataset in more files than 7.
Too many workers: 88%|██████████████████████████████████████████████████████████▋ | 7/8 [00:13<00:01, 1.89s/it]
Traceback (most recent call last):
File "distil-whisper/test_librispeech.py", line 36, in <module>
for i, batch in tqdm(enumerate(dataloader), total=num_examples, desc="Too many workers"):
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
for obj in iterable:
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1325, in _next_data
return self._process_data(data)
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
data.reraise()
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 7.
Original Traceback (most recent call last):
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/sanchitgandhi/hf/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
data.append(next(self.dataset_iter))
File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 986, in __iter__
yield from self._iter_pytorch(ex_iterable)
File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 920, in _iter_pytorch
for key, example in ex_iterable.shard_data_sources(worker_info.id, worker_info.num_workers):
File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 540, in shard_data_sources
self.ex_iterable.shard_data_sources(worker_id, num_workers),
File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 796, in shard_data_sources
self.ex_iterable.shard_data_sources(worker_id, num_workers),
File "/home/sanchitgandhi/datasets/src/datasets/iterable_dataset.py", line 126, in shard_data_sources
requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
File "/home/sanchitgandhi/datasets/src/datasets/utils/sharding.py", line 76, in _merge_gen_kwargs
for key in gen_kwargs_list[0]
IndexError: list index out of range
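One plausible reading of the traceback, sketched as a toy model: if the 7 shards on this node are spread over 14 workers, workers 7 to 13 receive no shard indices, and merging an empty list of per-shard kwargs fails at the `[0]` lookup. The helpers below (`shard_indices_for_worker`, `merge_gen_kwargs`, `kwargs_for_worker`) are hypothetical and the round-robin assignment is an assumption; this is not the actual `datasets` code, only an illustration of how that exception can arise.

```python
# Toy model only: these helpers are hypothetical and are NOT the datasets implementation.

def shard_indices_for_worker(num_shards, worker_id, num_workers):
    """Assumed round-robin split: worker k takes shards k, k + num_workers, ..."""
    return list(range(worker_id, num_shards, num_workers))


def merge_gen_kwargs(gen_kwargs_list):
    """Merge per-shard kwargs; indexes element 0, like the expression in the traceback."""
    return {
        key: [value for kwargs in gen_kwargs_list for value in kwargs[key]]
        for key in gen_kwargs_list[0]
    }


def kwargs_for_worker(gen_kwargs_list, worker_id, num_workers):
    """Select this worker's shards, then merge their kwargs."""
    indices = shard_indices_for_worker(len(gen_kwargs_list), worker_id, num_workers)
    return merge_gen_kwargs([gen_kwargs_list[i] for i in indices])


# 7 shards on this node and 14 requested workers (the numbers from the log above)
shards = [{"files": [f"shard-{i}"]} for i in range(7)]

first = kwargs_for_worker(shards, worker_id=0, num_workers=14)  # worker 0 gets one shard

try:
    kwargs_for_worker(shards, worker_id=7, num_workers=14)  # worker 7 gets no shards
    worker_7_error = None
except IndexError as err:
    worker_7_error = err
```

In this model the first 7 workers each get exactly one shard, and the failure shows up on worker 7, the same worker id reported in the traceback.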
Expected behavior
Should pass for both 5 and 8 examples
Environment info
- datasets version: 2.12.1.dev0
- Platform: Linux-5.13.0-1023-gcp-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.14.1
- PyArrow version: 12.0.0
- Pandas version: 2.0.1
cc @lhoestq in case you have any ideas here! This might need a multi-host set-up to debug (I can give you access to a JAX one if you need it)
I am also facing the same problem. Could you let me know if you found a solution for this?
I couldn't reproduce with the latest version of datasets (2.16.1). Can you update datasets and try again?
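For anyone who cannot upgrade right away, a possible workaround is to cap the number of DataLoader workers at the number of shards assigned to the node. This is an untested sketch, not a confirmed fix: `capped_num_workers` is a hypothetical helper, and it assumes `n_shards` on the split dataset reflects the per-node shard count, as the warning in the log suggests.

```python
def capped_num_workers(requested_workers: int, n_shards: int) -> int:
    """Never request more DataLoader workers than there are shards on this node."""
    return max(0, min(requested_workers, n_shards))


# Hypothetical usage with the objects from the reproduction script above:
# dataloader = DataLoader(
#     dataset_text,
#     batch_size=256,
#     num_workers=capped_num_workers(num_workers, dataset_text.n_shards),
#     drop_last=True,
# )
```

With the numbers from the report (14 requested workers, 7 shards per node), this would start 7 workers, so no worker would be left without a shard.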