`map` with `num_proc` > 1 leads to OOM
Describe the bug
When running `map` on a Parquet dataset loaded from my local machine, RAM usage increases linearly and eventually leads to OOM. I was wondering if I should save the cache file after every n steps in order to prevent this?
Steps to reproduce the bug
import datasets
from datasets import load_dataset

# feature_extractor and dataset_path are defined elsewhere in my script (not shown here)

def prepare_dataset(batch):
    # load audio
    sample = batch["audio"]
    inputs = feature_extractor(sample["array"], sampling_rate=16000)
    batch["input_values"] = inputs.input_values[0]
    batch["input_length"] = len(sample["array"].squeeze())
    return batch

ds = load_dataset("parquet", data_files=dataset_path, split="train")
ds = ds.shard(num_shards=4, index=0)
ds = ds.cast_column("audio", datasets.features.Audio(sampling_rate=16_000))
ds = ds.map(
    prepare_dataset,
    num_proc=32,
    writer_batch_size=1000,
    keep_in_memory=False,
    desc="preprocess dataset",
)
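For context, `feature_extractor` is not shown in the snippet above; in an audio pipeline like this it is presumably something along the lines of a Wav2Vec2FeatureExtractor from transformers. A minimal sketch under that assumption (the checkpoint name is hypothetical):

from transformers import Wav2Vec2FeatureExtractor

# Hypothetical setup: the issue does not show how feature_extractor is created.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")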
Expected behavior
It shouldn't run into an OOM problem.
Environment info
- `datasets` version: 2.18.0
- Platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
- Python version: 3.8.19
- `huggingface_hub` version: 0.22.2
- PyArrow version: 15.0.2
- Pandas version: 2.0.3
- `fsspec` version: 2024.2.0
Hi ! You can try to reduce `writer_batch_size`. It corresponds to the number of samples that stay in RAM before being flushed to disk.
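As a sketch, that would mean lowering the value in the `map` call (the value 100 below is just an illustration; tune it to your RAM budget):

ds = ds.map(
    prepare_dataset,
    num_proc=32,
    writer_batch_size=100,  # default is 1000; smaller values buffer fewer samples in RAM per worker before flushing
    keep_in_memory=False,
    desc="preprocess dataset",
)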