Memory leak / large memory usage with num_workers = 0 and numerous datasets within a DatasetDict
Describe the bug
Hi team, first off, I love the datasets library! 🥰
I'm encountering a potential memory leak / increasing memory usage when training a model on a very large DatasetDict.
Setup: I have a DatasetDict containing 362 distinct datasets, which sum up to ~2.8 billion rows.
Training Task: I'm performing contrastive learning with SentenceTransformer and Accelerate on a single node with 4 H100s, which requires me to sample from only one dataset at a time.
Training Loop: At each training step, I sample ~16,000 examples from a single dataset, and then switch to a different dataset for the next step. I iterate through all 362 datasets this way.
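For clarity, here is a minimal sketch of that per-step sampling pattern (toy sizes and hypothetical names; the real DatasetDict has 362 datasets and ~2.8B rows, and the actual training step uses SentenceTransformer + Accelerate):

```python
from datasets import Dataset, DatasetDict

# Toy stand-in for the real DatasetDict (362 datasets, ~2.8B rows in total).
dataset_dict = DatasetDict({
    f"dataset_{i}": Dataset.from_dict({"text": [f"example {j}" for j in range(10_000)]})
    for i in range(8)  # 362 in the real setup
})

samples_per_step = 16_000
for step, name in enumerate(dataset_dict.keys()):
    ds = dataset_dict[name]
    # Sample ~16k examples from this dataset only, then move on to the next one.
    batch = ds.shuffle(seed=step).select(range(min(samples_per_step, len(ds))))
    # ... contrastive training step on `batch` ...
    # Observed behaviour: RSS keeps growing across steps even though `batch`
    # and `ds` go out of scope at the end of each iteration.
```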
Problem: The process's memory usage increases continuously over time, eventually stalling training: the GPUs stop doing any work. It seems memory from previously sampled datasets isn't being released. I've set num_workers=0 for all experiments.
Chart 1: Standard DatasetDict. The memory usage (RSS) grows steadily until it makes the training stall.
Chart 2: IterableDatasetDict. I also tried using IterableDatasetDict and IterableDataset. The memory curve is "smoother," but the result is the same: it grows indefinitely and the training stalls.
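Roughly, the streaming variant looked like the sketch below (again toy sizes, and I'm assuming conversion via Dataset.to_iterable_dataset(); the exact construction in my code may differ):

```python
from itertools import islice
from datasets import Dataset, IterableDatasetDict

# Toy stand-in, as above; the real setup has 362 datasets / ~2.8B rows.
arrow_datasets = {
    f"dataset_{i}": Dataset.from_dict({"text": [f"example {j}" for j in range(10_000)]})
    for i in range(8)
}
iterable_dict = IterableDatasetDict({
    name: ds.to_iterable_dataset(num_shards=16) for name, ds in arrow_datasets.items()
})

samples_per_step = 16_000
for step, (name, ids) in enumerate(iterable_dict.items()):
    shuffled = ids.shuffle(seed=step, buffer_size=10_000)
    batch = list(islice(iter(shuffled), samples_per_step))
    # ... training step on `batch` ...
    # The growth is smoother than with the map-style DatasetDict, but RSS
    # still increases without bound over the 362 datasets.
```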
Any feedback or guidance on how to manage this memory would be greatly appreciated!
Steps to reproduce the bug
WIP: I'll add some code that manages to reproduce this error, but it's not straightforward.
Expected behavior
The memory usage should remain relatively constant or plateau after a few steps. Memory used for sampling one dataset should be released before or during the sampling of the next dataset.
Environment info
Python: 3.12
Datasets: 4.3.0
SentenceTransformers: 5.1.1
Thanks for the report, this is possibly related to #7722 and #7694.
Could you please provide steps to reproduce this?
To overcome this issue for now, I simply reduced the size of the dataset and ended up running a for loop (my training now uses a constant learning rate schedule). From what I understood, and I don't know if it's possible, the solution would be to tell the datasets backend to keep x% of the memory free (including memory-mapped pages). I can't release the data right now, but I will, and that will make it possible to reproduce this issue. It will require a few free TB of disk, though.
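In the meantime, the workaround looks roughly like the sketch below (hypothetical repository and dataset names; the actual training script isn't public yet):

```python
import gc
from datasets import load_dataset

dataset_names = [f"dataset_{i}" for i in range(362)]  # hypothetical names

for name in dataset_names:
    # Load one (reduced) dataset at a time instead of keeping the whole
    # DatasetDict resident. "my_org/my_corpus" is a placeholder.
    ds = load_dataset("my_org/my_corpus", name, split="train")
    # ... one training pass over `ds` with a constant learning-rate schedule ...
    del ds
    gc.collect()  # encourage Python to release references between datasets
```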
@raphaelsty thanks for coming back to this. I assume you are running in streaming mode? That should prevent these errors, but it looks like more people than just you have this problem, so an example that clearly reproduces the issue (including data + code) would be highly appreciated.
This could be related to this issue: https://github.com/huggingface/datasets/issues/4883, in which we discussed how RSS and memory mapping work and how they depend on the OS and disk type.
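One way to tell whether the RSS growth is reclaimable page cache from the memory-mapped Arrow files or memory that is actually retained (a suggestion on top of the linked discussion, not something from the original report) is to log RSS next to USS with psutil during training:

```python
import psutil

def log_memory(step: int) -> None:
    # memory_full_info() exposes uss (and pss on Linux) in addition to rss.
    mem = psutil.Process().memory_full_info()
    print(
        f"step={step} "
        f"rss={mem.rss / 1e9:.2f} GB "  # includes OS-cached memory-mapped pages
        f"uss={mem.uss / 1e9:.2f} GB"   # pages unique to this process
    )
```

If RSS grows while USS stays roughly flat, the growth is mostly the OS caching the memory-mapped Arrow files and should be reclaimed under memory pressure; if USS grows as well, something is genuinely being retained across datasets.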