PyArrow Dataset error when calling `load_dataset`

Open piraka9011 opened this issue 3 years ago • 3 comments

Describe the bug

I am fine-tuning a wav2vec2 model on my own dataset, following this script: https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py

Loading my audio dataset from the Hub (it was originally built from local files on disk) fails with the following PyArrow error:

File "/home/ubuntu/w2v2/run_speech_recognition_ctc.py", line 227, in main
  raw_datasets = load_dataset(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/load.py", line 1679, in load_dataset
  builder_instance.download_and_prepare(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 704, in download_and_prepare
  self._download_and_prepare(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 793, in _download_and_prepare
  self._prepare_split(split_generator, **prepare_split_kwargs)
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/builder.py", line 1268, in _prepare_split
  for key, table in logging.tqdm(
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/tqdm/std.py", line 1195, in __iter__
  for obj in iterable:
File "/home/ubuntu/.virtualenvs/meval/lib/python3.8/site-packages/datasets/packaged_modules/parquet/parquet.py", line 68, in _generate_tables
  for batch_idx, record_batch in enumerate(
File "pyarrow/_parquet.pyx", line 1309, in iter_batches
File "pyarrow/error.pxi", line 121, in pyarrow.lib.check_status
pyarrow.lib.ArrowNotImplementedError: Nested data conversions not implemented for chunked array outputs

Steps to reproduce the bug

I created the dataset from a JSON Lines manifest whose entries have audio_filepath, text, and duration fields.

When creating the dataset, I do something like this:

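import json
from datasets import Dataset, Audio

# Hypothetical placeholders: point these at your own files and Hub repo.
root_path = "/path/to/audio"
artifact_name = "my-asr-dataset"
org_name = "my-org"

# manifest_lines holds the raw lines of a JSON Lines manifest; each non-empty
# line is a JSON object with "audio_filepath", "duration", and "text" keys.
with open(f"{root_path}/manifest.jsonl") as f:  # hypothetical manifest path
    manifest_lines = f.readlines()

manifest_dict = {"audio": [], "duration": [], "transcription": []}
for line in manifest_lines:
    line = line.strip()
    if line:
        line_dict = json.loads(line)
        manifest_dict["audio"].append(f"{root_path}/{line_dict['audio_filepath']}")
        manifest_dict["duration"].append(line_dict["duration"])
        manifest_dict["transcription"].append(line_dict["text"])

# Create a HF dataset; the "audio" column is decoded (and resampled if needed) at 16 kHz
dataset = Dataset.from_dict(manifest_dict).cast_column(
    "audio", Audio(sampling_rate=16_000),
)

# From the docs for saving to disk: store the audio bytes instead of local file paths
# https://huggingface.co/docs/datasets/v2.3.2/en/package_reference/main_classes#datasets.Dataset.save_to_disk
def read_audio_file(example):
    with open(example["audio"]["path"], "rb") as f:
        return {"audio": {"bytes": f.read()}}

dataset = dataset.map(read_audio_file, num_proc=70)
dataset.save_to_disk(f"/audio-data/hf/{artifact_name}")
dataset.push_to_hub(f"{org_name}/{artifact_name}", max_shard_size="5GB", private=True)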

When my training script then calls load_dataset() on this same dataset, downloading it from the Hugging Face Hub, I get the stack trace above. Loading the local copy with load_from_disk() works fine.

Expected results

load_dataset() should load the dataset without errors, just as load_from_disk() does.

Actual results

See the stack trace above.

Environment info

I am using the huggingface/transformers-pytorch-gpu:latest Docker image.

  • datasets version: 2.3.0
  • Platform: Docker/Ubuntu 20.04
  • Python version: 3.8
  • PyArrow version: 8.0.0

piraka9011, Jul 20 '22 01:07

Hi! It looks like a bug in PyArrow. If you manage to end up with only one chunk per Parquet file, that should work around the issue.

To achieve that, you can try lowering max_shard_size and avoid calling map before push_to_hub.

Do you have a minimal reproducible example that we can share with the Arrow team for further debugging?

lhoestq, Jul 21 '22 17:07
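To get a feel for how lowering max_shard_size translates into a number of shards, here is a rough back-of-the-envelope sketch that estimates the raw audio size from the manifest's duration field. It assumes every clip is a 16 kHz, mono, 16-bit PCM WAV file with a 44-byte header (the report does not state the file format; compressed formats such as FLAC or MP3 would be much smaller). The helper names are hypothetical, and datasets computes its own size estimate from the Arrow table, so treat the result as an approximation only.

```python
import math

# Assumed clip format (not stated in the report): 16 kHz, mono, 16-bit PCM WAV.
SAMPLING_RATE = 16_000
BYTES_PER_SAMPLE = 2
WAV_HEADER_BYTES = 44


def estimated_wav_bytes(duration_s):
    """Approximate size in bytes of one clip under the assumed format."""
    return WAV_HEADER_BYTES + int(duration_s * SAMPLING_RATE) * BYTES_PER_SAMPLE


def estimated_num_shards(durations, max_shard_bytes):
    """Rough shard count if every shard is filled up to max_shard_bytes."""
    total_bytes = sum(estimated_wav_bytes(d) for d in durations)
    return math.ceil(total_bytes / max_shard_bytes)


# Example: 100 hours of audio split into 10-second clips.
durations = [10.0] * 36_000
for label, max_bytes in [("5GB", 5 * 10**9), ("2GB", 2 * 10**9), ("1GB", 10**9)]:
    print(label, estimated_num_shards(durations, max_bytes))
```

Under these assumptions, 100 hours of 10-second clips comes to roughly 11.5 GB of audio bytes, i.e. about 3 shards at "5GB", 6 at "2GB", and 12 at "1GB".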

> If you manage to end up with only one chunk per Parquet file, that should work around the issue.

Yup, I did not hit this bug when I was testing my script on a slice of fewer than 1,000 samples from my dataset.

> Do you have a minimal reproducible example...

I'm not sure I can make it more minimal than the script I shared above. Are you asking for a sample JSON file? You could just generate a random manifest list; I can add that to the script above if that's what you mean.

piraka9011, Jul 21 '22 18:07
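For a shareable reproduction along the lines of "just generate a random manifest list", one option is to produce random clips plus a matching JSON Lines manifest using only the standard library. This is a sketch under stated assumptions: the file names, word list, and duration range are made up, and the clips are random noise written as 16 kHz mono 16-bit WAV files so that the manifest can be fed into the dataset-building script from the report. Reaching multi-GB shards this way needs a lot of disk space.

```python
import json
import os
import random
import tempfile
import wave


def write_noise_wav(path, duration_s, sampling_rate=16_000):
    """Write a mono, 16-bit PCM WAV file filled with random noise."""
    num_frames = int(duration_s * sampling_rate)
    with wave.open(path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wav.setframerate(sampling_rate)
        wav.writeframes(os.urandom(num_frames * 2))


def make_random_manifest(root_path, num_clips, seed=0):
    """Write num_clips noise clips under root_path and return JSON Lines
    strings with the audio_filepath / duration / text keys from the report."""
    rng = random.Random(seed)
    words = ["hello", "world", "speech", "audio", "model", "sample"]  # made-up vocabulary
    lines = []
    for i in range(num_clips):
        rel_path = f"clip_{i:06d}.wav"  # hypothetical naming scheme
        duration = round(rng.uniform(1.0, 15.0), 2)
        write_noise_wav(os.path.join(root_path, rel_path), duration)
        text = " ".join(rng.choice(words) for _ in range(rng.randint(3, 12)))
        lines.append(json.dumps(
            {"audio_filepath": rel_path, "duration": duration, "text": text}
        ))
    return lines


# Example: five clips in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    for line in make_random_manifest(tmp_dir, 5):
        print(line)
```

Using noise rather than silence keeps the audio bytes from compressing away, which is closer to real speech data.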

Actually, this is probably linked to this open issue: https://issues.apache.org/jira/browse/ARROW-5030.

Setting max_shard_size="2GB" should do the job (or max_shard_size="1GB" to be on the safe side, especially since shard sizes can vary if the dataset is not evenly distributed).

lhoestq, Jul 22 '22 13:07
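One plausible reading of the "2GB" advice, assuming the ARROW-5030 failure comes from a nested column (the Audio feature is stored as a struct of bytes and path) whose binary data in one Parquet file exceeds what Arrow's 32-bit offsets can address (2^31 - 1 bytes, about 2.15 GB), is that each shard must stay below that limit with some headroom. The sketch below checks a size string against that limit. This mechanism is an assumption, not something confirmed in the thread; parse_size is a hypothetical helper that assumes decimal units (1 GB = 10^9 bytes), and the variance factor models shards that come out larger than the target.

```python
# Largest byte length a single Arrow binary array can address with the
# 32-bit offsets used by the (non-"large") binary and string types.
INT32_OFFSET_LIMIT = 2**31 - 1

# Decimal units are an assumption of this sketch.
_UNITS = {"KB": 10**3, "MB": 10**6, "GB": 10**9}


def parse_size(size):
    """Hypothetical helper: turn strings such as "2GB" into a byte count."""
    size = size.strip().upper()
    for suffix, factor in _UNITS.items():
        if size.endswith(suffix):
            return int(float(size[: -len(suffix)]) * factor)
    return int(size)  # plain byte count, e.g. "1048576"


def fits_under_offset_limit(max_shard_size, variance=0.0):
    """Return True if a shard up to `variance` (0.1 = 10%) larger than
    max_shard_size still stays at or below INT32_OFFSET_LIMIT bytes."""
    return parse_size(max_shard_size) * (1 + variance) <= INT32_OFFSET_LIMIT


for candidate in ["5GB", "2GB", "1GB"]:
    print(candidate,
          fits_under_offset_limit(candidate),
          fits_under_offset_limit(candidate, variance=0.1))
```

Under these assumptions, "5GB" is well over the limit, while "2GB" leaves only about 7% headroom (2,147,483,647 / 2,000,000,000 ≈ 1.07), which is consistent with suggesting "1GB" when shard sizes vary.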