With dataloader RSS memory consumed by HF datasets monotonically increases
Describe the bug
When HF datasets is used in conjunction with a PyTorch DataLoader, the RSS memory of the process keeps increasing, when it should stay constant.
Steps to reproduce the bug
Run and observe the output of this snippet which logs RSS memory.
import psutil
import os
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 32
NUM_TRIES = 10
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader):
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(count, idx, mem_after - mem_before)
    count += 1
Expected results
Memory should not increase after initial setup and loading of the dataset
Actual results
Memory continuously increases as can be seen in the log.
Environment info
- datasets version: 2.3.2
- Platform: Linux-4.19.0-21-cloud-amd64-x86_64-with-glibc2.10
- Python version: 3.8.13
- PyArrow version: 7.0.0
Are you sure there is a leak? How can I see it? You shared the script but not the output which you believe should indicate a leak.
I modified your reproduction script to print only once per try, as your original was printing too much info. Also, you absolutely must add gc.collect() when doing any memory measurements, since Python's GC is scheduled lazily, so you might otherwise be measuring the wrong thing. This gives us:
import psutil
import os
import gc
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
BATCH_SIZE = 32
NUM_TRIES = 100
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader): pass
    gc.collect()
    mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    print(count, mem_after - mem_before)
    count += 1
Now running it:
$ python dl-leak.py
Reusing dataset imdb (/home/stas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
0 4.43359375
1 4.4453125
2 4.44921875
3 4.44921875
4 4.4609375
5 4.46484375
6 4.46484375
7 4.46484375
8 4.46484375
9 4.46484375
10 4.46484375
11 4.46484375
12 4.46484375
13 4.46484375
14 4.46484375
15 4.46484375
16 4.46484375
It's normal that at the beginning there is a small growth in memory usage, but after 5 cycles it gets steady.
Unless of course you're referring to the memory growth during the first try. Is that what you mean? And since your ds is small it's hard to see the growth - could it be just because some records are longer and it needs to allocate more memory for those?
Though while experimenting with this I observed a peculiar thing: if I concatenate 2 datasets, I don't see any growth at all. But that's probably because the program allocated additional peak RSS memory for the concatenation and is then re-using that memory.
I basically tried making the dataset much longer to check that there is no memory growth once the 780 records of the imdb ds have been processed once.
It is hard to say if it is directly reproducible in this setup. Maybe it is specific to the images stored in the CM4 dataset, which cause a memory leak. I am still running your script to see if I can reproduce that particular leak here.
I was able to reproduce the leak with:
import psutil
import os
import gc
from datasets import load_from_disk
import time
DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"
dataset = load_from_disk(DATASET_PATH)
# truncate to a tiny dataset
dataset = dataset.select(range(1000))
print(f"dataset: {len(dataset)} records")
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, rec in enumerate(dataset):
    if idx % 100 == 0:
        gc.collect()
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
You need to adjust the DATASET_PATH value; you get the data from
gsutil -m cp "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset.arrow" "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset_info.json" "gs://hf-science-m4/cm4/cm4-10000-v0.1/state.json" .
(I assume the hf folks have the perms) - it's a smallish dataset (10k)
then you run:
$ python ds.py
dataset: 1000 records
0 1.0156MB
100 126.3906MB
200 142.8906MB
300 168.5586MB
400 218.3867MB
500 230.7070MB
600 238.9570MB
700 263.3789MB
800 288.1289MB
900 300.5039MB
you should be able to see the leak
This issue has nothing to do with PIL's decoder. I removed it and the problem is still there.
I then traced this leak to this single call: pa_table.to_pydict() here:
https://github.com/huggingface/datasets/blob/08a7b389cdd6fb49264a72aa8ccfc49a233494b6/src/datasets/formatting/formatting.py#L138-L140
I can make it leak much faster by modifying that code to repeat pa_table.to_pydict() many times in a row. It shouldn't have that impact:
class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
    def extract_row(self, pa_table: pa.Table) -> dict:
        x = [pa_table.to_pydict() for x in range(200)]
        return _unnest(pa_table.to_pydict())
@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?
Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.
The problem already happens with pyarrow==6.0.0 or later (minimum for current datasets)
I'm also trying to dig in with objgraph to see if there are any circular references which prevent objects from being freed, but no luck there so far. And I'm pretty sure to_pydict is not Python code, so the problem likely happens somewhere outside of Python's GC.
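For reference, this is roughly the kind of objgraph probing I have in mind (a sketch; the dataset and slice size are just for illustration):
import gc
import objgraph
from datasets import load_dataset
dataset = load_dataset("imdb", split="train")
objgraph.show_growth(limit=10)  # baseline snapshot of object counts per type
for i in range(1000):
    _ = dataset[i]  # pull rows through the formatting path
gc.collect()
objgraph.show_growth(limit=10)  # any types that keep growing here would be leak suspects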
This appears to be the same issue I think: https://github.com/huggingface/datasets/issues/4528 I dug into the repro code there and it's the same behavior with the same leak, but it's a pure nlp dataset and thus much faster to work with.
I went all the way back to pyarrow==1.0.0 and datasets==1.12.0 and the problem is still there. How is it even possible that it wasn't noticed all this time?
Could it be that the leak is in some 3rd-party component pyarrow relies on? While downgrading, I only downgraded the above 2 packages.
Also found this warning
Be careful: if you don't pass the ArrowArray struct to a consumer, array memory will leak. This is a low-level function intended for expert users.
see: https://github.com/apache/arrow/blob/99b57e84277f24e8ec1ddadbb11ef8b4f43c8c89/python/pyarrow/table.pxi#L2515-L2517
perhaps something triggers this condition?
I have no idea if it's related - this is just something that came up during my research.
Does it crash with OOM at some point? If it doesn't, it isn't a leak, just aggressive caching or a custom allocator that doesn't like to give memory back (not uncommon). #4528 looks like it hits a steady state.
I believe the underlying arrow libs use a custom C allocator. Some of those are designed not to give back to OS, but keep heap memory for themselves to re-use (hitting up the OS involves more expensive mutex locks, contention, etc). The greedy behaviour can be undesirable though. There are likely flags to change the allocator behaviour, and one could likely build without any custom allocators (or use a different one).
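For instance, pyarrow exposes a few allocator knobs that could be worth toggling (a sketch; no claim that any of these help in this particular case):
import pyarrow as pa
# option 1: switch the memory-pool backend (jemalloc / mimalloc / plain system malloc)
pa.set_memory_pool(pa.system_memory_pool())
# option 2: with the default jemalloc backend, ask it to return freed memory to the OS promptly
pa.jemalloc_set_decay_ms(0)
# option 3: log every allocation/free the Arrow allocator performs, to see what it is doing
pa.log_memory_allocations(enable=True)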
Does it crash with OOM at some point?
In the original setup where we noticed this problem, it was indeed ending in an OOM
https://github.com/huggingface/datasets/issues/4528 looks like it hits a steady state.
@rwightman in the plot I shared, the steady state comes from the time.sleep(100) I added in the end of the script, to showcase that even the garbage collector couldn't free that allocated memory.
Could this be related to this discussion about a potential memory leak in pyarrow: https://issues.apache.org/jira/browse/ARROW-11007 ?
(Note: I've tried import pyarrow; pyarrow.jemalloc_set_decay_ms(0) and the memory leak is still happening on your toy example)
@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?
to_pydict calls to_pylist on each column (i.e. on each PyArrow Array). Then it iterates on the array and calls as_py on each element. The as_py implementation depends on the data type. For strings I think it simply gets the buffer that contains the binary string data that is defined in C++
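In other words, roughly this (a sketch of that description, not the actual C++ code path):
import pyarrow as pa
table = pa.table({"text": ["foo", "bar"], "label": [0, 1]})
# approximately what to_pydict does: to_pylist per column,
# which in turn calls .as_py() on every element scalar
pydict = {
    name: [scalar.as_py() for scalar in table.column(name)]
    for name in table.column_names
}
assert pydict == table.to_pydict()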
The Arrow team is pretty responsive at [email protected] if it can help
Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.
That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?
That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?
I added you to the bucket @lhoestq
It looks like an issue with memory mapping:
- the amount of memory used in the end corresponds to the size of the dataset
- setting keep_in_memory=True in load_from_disk loads the dataset in RAM, and doesn't cause any memory leak
Here is code to reproduce this issue using only PyArrow and a dummy arrow file:
import psutil
import os
import gc
import pyarrow as pa
import time
ARROW_PATH = "tmp.arrow"
if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000)  # ~200MB
    table = pa.table({"a": arr})
    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()
def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
prints
0 0.2500MB
100 19.8008MB
200 39.3320MB
300 58.8633MB
400 78.3945MB
500 97.9258MB
600 117.4570MB
700 136.9883MB
800 156.5195MB
900 176.0508MB
Note that this example simply iterates over the pyarrow.lib.BinaryScalar objects in the array. Running .as_py() is not needed to experience the memory issue.
@lhoestq that does indeed increase in memory, but if you iterate over the array again after the first time, or re-open and remap the same file (repeat table = memory_mapped_arrow_table_from_file(ARROW_PATH)) before re-iterating, it doesn't move past 195MB... it would appear another step is needed to continue consuming memory past that. Hmmm.
Are the pa_tables held on to anywhere after they are iterated in the real code?
In my hack, if you do a bunch of cut & paste and then change the arr name for each iteration:
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr1 = table[0]
for idx, x in enumerate(arr1):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr2 = table[0]
for idx, x in enumerate(arr2):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
it leaks; if all arr have the same name (so the previous one gets cleaned up) it does not, and goes back to 0. Is there anything that could be holding onto a reference to an intermediary equivalent of arr in the real use case?
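In other words, something along these lines stays flat (a sketch reusing the setup above, just re-binding the same names on each pass so the previous table/array can be collected):
for _ in range(3):
    table = memory_mapped_arrow_table_from_file(ARROW_PATH)
    arr = table[0]
    for idx, x in enumerate(arr):
        if idx % 100 == 0:
            gc.collect()
            time.sleep(0.1)
            mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
            print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
    del table, arr  # drop the references before remapping the file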
Yes, we have already established here https://github.com/huggingface/datasets/issues/4883#issuecomment-1232063891 that when one iterates over the whole dataset multiple times, it consumes a bit more memory in the next few repetitions and then remains steady.
Which means that when a new iterator is created over the same dataset, all the memory from the previous iterator is re-used.
So the leak happens primarily when the iterator is "drained" the first time, which tells me that either a circular reference is created somewhere that only gets released when the iterator is destroyed, or there is some global variable that keeps piling up the memory and doesn't release it in time.
Also, I noticed some __del__ methods, which won't destroy objects automatically, and there is usually a warning against using them: https://stackoverflow.com/a/1481512/9201239
There are also some weakrefs in the code, which may also lead to leaks or weird problems at times.
@stas00 my point was, I'm not convinced @lhoestq's last example illustrates the leak, but rather the differences between memory-mapped and in-memory usage patterns. If you destroy arr, the memory-map impl goes back to 0 each iteration. The amount of memory that 'looks' like it is leaked in the first pass differs quite a bit between memory-mapped vs in-memory, but the underlying issue is likely a circular reference, or reference(s) which were not cleaned up, that would impact either case - just much more visibly with mmap.
Thank you for clarifying, Ross.
I think we agree that it's almost certain that the datasets iterator traps some inner variable that prevents object freeing, since if we create the iterator multiple times (and drain it), after a few runs no new memory is allocated. We could try to dig in more with objgraph - my main concern is that if the problem happens somewhere outside of Python (i.e. in the pyarrow C++ implementation), it'd be much more difficult to trace.
I wish there was a way on linux to tell the program to free no longer used memory at will.
FWIW, I revisited some code I had in the works to use HF datasets w/ timm train & val scripts. There is no leak there across multiple epochs. It uses the defaults.
It's worth noting that with imagenet keep_in_memory=True isn't even an option because the train arrow file is ~140GB and my local memory is less. The virtual address space reflects mmap (> 150GB) and doesn't increase over epochs that I noticed. I have some perf issues to bring up wrt to the current setup, but that's a separate and lower prio discussion to have elsewhere...
Notes
After reading many issues and trying many things, here is a summary of my learnings.
I'm now using @lhoestq's repro case as it's pyarrow-isolated: https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985
1. pyarrow memory backends
pyarrow has 3 memory-pool backends; I tried them all with the same results:
pa.set_memory_pool(pa.jemalloc_memory_pool())
pa.set_memory_pool(pa.mimalloc_memory_pool())
pa.set_memory_pool(pa.system_memory_pool())
2. quick release
The jemalloc backend supports quick release
pa.jemalloc_set_decay_ms(0)
it doesn't make any difference in this case
3. actual memory allocations
this is a useful tracer for PA memory allocators
pa.log_memory_allocations(enable=True)
it nicely reports memory allocations and releases when the arrow file is created the first time.
but when we then try to do enumerate(arr) this logger reports 0 allocations.
This summary also reports no allocations when the script is run the second time (post file creation):
mem_pool = pa.default_memory_pool()
print(f"PyArrow mem pool info: {mem_pool.backend_name} backend, {mem_pool.bytes_allocated()} allocated, "
f"{mem_pool.max_memory()} max allocated, ")
print(f"PyArrow total allocated bytes: {pa.total_allocated_bytes()}")
However, it's easy to see using tracemalloc, which only measures Python's memory allocations, that it's PA that leaks, since tracemalloc shows fixed memory usage
(this is bolted on top of the original repro script)
import tracemalloc
tracemalloc.start()
[...]
for idx, x in enumerate(arr):
    if idx % 10 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
        mem_use = pa.total_allocated_bytes() - start_use
        mem_peak = pool.max_memory() - start_peak_use
        second_size, second_peak = tracemalloc.get_traced_memory()
        mem_diff = (second_size - first_size) / 2**20
        mem_peak_diff = (second_peak - first_peak) / 2**20
        # pa.jemalloc_memory_pool().release_unused()
        # pa.mimalloc_memory_pool().release_unused()
        # pa.system_memory_pool().release_unused()
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {memory_mapped_stream.size()/2**20:4.4}MB {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")
gives:
0 5.4258MB 0.0110 0.0201 195.3MB 0.0MB 0.0MB
10 25.3672MB 0.0112 0.0202 195.3MB 0.0MB 0.0MB
20 45.9336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
30 62.4336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
40 83.0586MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
50 103.6836MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
60 124.3086MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
70 140.8086MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
80 161.4336MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
90 182.0586MB 0.0112 0.0203 195.3MB 0.0MB 0.0MB
the 3rd and 4th columns are tracemalloc's report.
the 5th column is the size of the mmapped stream - fixed.
the last 2 are PA's malloc reports - you can see they're totally fixed at 0.
So what gives? PA's memory allocator says nothing was allocated and we can see python doesn't allocate any memory either.
Someone suggested in one of the PA issues that IPC/gRPC could be the issue. Any suggestions on how to debug this one?
The main issue is that one can't step through with a Python debugger, as arr is an opaque C++ object bound to Python.
Please see the next comment for a possible answer.
ref-count
I also traced reference counts and they are all fixed using either sys.getrefcount(x) or len(gc.get_referrers(x))
so it's not the python object
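i.e. a probe along these lines, bolted onto the iteration loop of the repro script, prints constant numbers throughout:
import gc
import sys
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        # both values stay flat for the duration of the loop
        print(idx, sys.getrefcount(x), len(gc.get_referrers(x)))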
Important related discussions
https://issues.apache.org/jira/browse/ARROW-11007 - looks very similar to our issue in particular this part of the report: https://issues.apache.org/jira/browse/ARROW-11007?focusedCommentId=17279642&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17279642
There is no leak, just badly communicated linux RSS memory usage stats
Next, lets revisit @rwightman's suggestion that there is actually no leak.
After all - we are using mmap which will try to map the file to RAM as much as it can and then page out if there is no memory. i.e. MMAP is only fast if you have a lot of CPU RAM.
So let's do it:
Memory mapping OOM test
We first quickly start a cgroups-controlled shell which will instantly kill any program that consumes more than 1GB of memory:
$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=1G --setenv="MEMLIMIT=1GB" bash
Let's check that it indeed does so. Let's change @lhoestq's script to allocate a 10GB arrow file:
$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 5000)'
Killed
oops, that didn't work, as we tried to allocate 10GB when only 1GB is allowed. This is what we want!
Let's do a sanity check - can we allocate 0.1GB?
$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 50)'
Yes. So the limited shell does the right thing. It lets us allocate < 1GB of RSS RAM.
Next let's go back to @lhoestq's script, but with a 10GB arrow file.
We change his repro script https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985 to use a 50x larger file:
arr = pa.array([b"a" * (2000 * 1024)] * 5000) # ~10000MB
we first have to run it once in a normal, unlimited shell so that we don't get killed (as creating the arrow file allocates 10GB)
let's run the script now in the 1GB-limited shell while running a monitor:
$ htop -F python -s M_RESIDENT -u `whoami`
so we have 2 sources of RSS info just in case.
$ python pyar.py
0 4.3516MB 0.0103 0.0194 9.766e+03MB 0.0MB 0.0MB
10 24.3008MB 0.0104 0.0195 9.766e+03MB 0.0MB 0.0MB
[...]
4980 9730.3672MB 0.0108 0.0199 9.766e+03MB 0.0MB 0.0MB
4990 9750.9922MB 0.0108 0.0199 9.766e+03MB 0.0MB 0.0MB
PyArrow mem pool info: jemalloc backend, 0 allocated, 0 max allocated,
PyArrow total allocated bytes: 0
But wait, it reported 10GB RSS both in htop and in our log!
So that means it never actually allocated 10GB, otherwise it'd have been killed.
Which tells us that there is no leak whatsoever; this is just a really tricky situation where MMAPPED memory is reported as part of RSS, which it arguably shouldn't be, since now we have no way to measure real memory usage.
I also attached the script with all the different things I have tried in it, so it should be easy to turn them on/off if you want to reproduce any of my findings.
just rename it to pyar.py, as GitHub doesn't allow attaching scripts...
(I have to remember to exit that special mem-limited shell or else I won't be able to do anything serious there.)
The original leak in the multi-modal code is very likely something else. But of course now it'd be very difficult to trace it using mmap.
I think to debug this we have to set keep_in_memory=True in load_from_disk to load the small dataset in RAM, so there is no misleading mmap reporting component, and then continue searching for another source of a leak.
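i.e. something like this, reusing the earlier CM4 repro (same DATASET_PATH as before):
import gc
import os
import psutil
from datasets import load_from_disk
DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"
# load fully into RAM so that RSS growth can no longer be explained by mmap paging
dataset = load_from_disk(DATASET_PATH, keep_in_memory=True)
dataset = dataset.select(range(1000))  # truncate to a tiny dataset
mem_before = psutil.Process(os.getpid()).memory_info().rss / 2**20
for idx, rec in enumerate(dataset):
    if idx % 100 == 0:
        gc.collect()
        mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")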
To add to what @stas00 found, I'm gonna leave some links to where I believe the confusion came from in pyarrow's APIs, for future reference:
- In the section where they talk about efficiently writing and reading arrow data, they give an example of how
Arrow can directly reference the data mapped from disk and avoid having to allocate its own memory.
And where their example shows 0 RSS memory allocation, the way we used to measure RSS shows 39.6719MB allocated. Here's the script to reproduce:
import psutil
import os
import gc
import pyarrow as pa
import time
import sys
# gc.set_debug(gc.DEBUG_LEAK)
# gc.set_threshold(0,0,0)
#pa.set_memory_pool(pa.mimalloc_memory_pool())
#pa.set_memory_pool(pa.system_memory_pool())
import tracemalloc
#pa.jemalloc_set_decay_ms(0)
# pa.log_memory_allocations(enable=True)
BATCH_SIZE = 10000
NUM_BATCHES = 1000
schema = pa.schema([pa.field('nums', pa.int32())])
with pa.OSFile('bigfile.arrow', 'wb') as sink:
    with pa.ipc.new_file(sink, schema) as writer:
        for row in range(NUM_BATCHES):
            batch = pa.record_batch([pa.array(range(BATCH_SIZE), type=pa.int32())], schema)
            writer.write(batch)
start_use = pa.total_allocated_bytes()
pool = pa.default_memory_pool()
start_peak_use = pool.max_memory()
tracemalloc.start()
first_size, first_peak = tracemalloc.get_traced_memory()
mem_before = psutil.Process(os.getpid()).memory_info().rss / 2**20
# with pa.OSFile('bigfile.arrow', 'rb') as source:
#     loaded_array = pa.ipc.open_file(source).read_all()
with pa.memory_map('bigfile.arrow', 'rb') as source:
    loaded_array = pa.ipc.open_file(source).read_all()
print("LEN:", len(loaded_array))
print("RSS: {}MB".format(pa.total_allocated_bytes() >> 20))
gc.collect()
time.sleep(0.1)
mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
mem_use = pa.total_allocated_bytes() - start_use
mem_peak = pool.max_memory() - start_peak_use
second_size, second_peak = tracemalloc.get_traced_memory()
mem_diff = (second_size - first_size) / 2**20
mem_peak_diff = (second_peak - first_peak) / 2**20
idx = 0
print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")
gives:
LEN: 10000000
RSS: 0MB
0 39.6719MB 0.0132 0.0529 0.0MB 0.0MB
Which again just proves that we incorrectly measure RSS in the case of MMAPPED memory.
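One way to factor the file-backed mmap pages out on Linux is to read the RssAnon / RssFile split from /proc/self/status, instead of the single rss number psutil gives us (a sketch; these fields require a reasonably recent kernel):
import os
def rss_breakdown():
    """Return anonymous vs file-backed resident memory, in MB (Linux only)."""
    parts = {}
    with open(f"/proc/{os.getpid()}/status") as f:
        for line in f:
            if line.startswith(("RssAnon", "RssFile", "RssShmem")):
                key, value = line.split(":")
                parts[key] = int(value.strip().split()[0]) / 1024  # kB -> MB
    return parts
print(rss_breakdown())  # mmapped arrow pages should show up under RssFile, not RssAnon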
@lhoestq, I have been working on a detailed article showing that MMAP doesn't leak; it's mostly done and I will share it when it's ready.
The issue is that we still need to be able to debug memory leaks by turning MMAP off.
But once I tried to show the user that using load_dataset(..., keep_in_memory=True) is the way to debug an actual memory leak - guess what I discovered? A potential actual leak.
Here is the repro:
$ cat ds-mmap.py
from datasets import load_dataset
import gc
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']
print(f"{'idx':>6} {'RSS':>10} {'Δ RSS':>15}")
step = 20000
for i in range(0, 10*step, step):
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:6d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB ")
$ python ds-mmap.py
Reusing dataset wmt19 (/home/stas/.cache/huggingface/datasets/wmt19/cs-en/1.0.0/c3db1bf4240362ed1ef4673b354f468d70aac66d4e67d45f536d493a0840f0d3)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5.66it/s]
idx RSS Δ RSS
0 1398.4609MB 3.5195MB
20000 1398.5742MB 0.1133MB
40000 1398.6016MB 0.0273MB
60000 1398.6016MB 0.0000MB
80000 1398.6016MB 0.0000MB
100000 1398.6328MB 0.0312MB
120000 1398.6953MB 0.0625MB
140000 1398.6953MB 0.0000MB
160000 1398.7500MB 0.0547MB
180000 1398.7500MB 0.0000MB
As I suggested on Slack, perhaps it was due to variation in dataset record lengths, so with your help I wrote another repro with synthetic records which are all identical - which should remove my hypothesis from the equation, and we should expect 0 incremental growth as we iterate over the dataset. But alas this is not the case. There is a tiny but definite leak-like behavior.
Here is the new repro:
$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()
dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")
print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 25_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations:  # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()
print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")
and the run:
$ python ds-synthetic-no-mmap.py
Dataset len=1000000
idx RSS Δ RSS
0 1601.9258MB 47.9688MB
25000 1641.6289MB 39.7031MB
50000 1641.8594MB 0.2305MB
75000 1642.1289MB 0.2695MB
100000 1642.1289MB 0.0000MB
125000 1642.3789MB 0.2500MB
150000 1642.3789MB 0.0000MB
175000 1642.6289MB 0.2500MB
200000 1642.6289MB 0.0000MB
225000 1642.8789MB 0.2500MB
250000 1642.8828MB 0.0039MB
275000 1643.1328MB 0.2500MB
300000 1643.1328MB 0.0000MB
325000 1643.3828MB 0.2500MB
350000 1643.3828MB 0.0000MB
375000 1643.6328MB 0.2500MB
400000 1643.6328MB 0.0000MB
425000 1643.8828MB 0.2500MB
450000 1643.8828MB 0.0000MB
475000 1644.1328MB 0.2500MB
500000 1644.1328MB 0.0000MB
525000 1644.3828MB 0.2500MB
550000 1644.3828MB 0.0000MB
575000 1644.6328MB 0.2500MB
600000 1644.6328MB 0.0000MB
625000 1644.8828MB 0.2500MB
650000 1644.8828MB 0.0000MB
675000 1645.1328MB 0.2500MB
700000 1645.1328MB 0.0000MB
725000 1645.3828MB 0.2500MB
750000 1645.3828MB 0.0000MB
775000 1645.6328MB 0.2500MB
800000 1645.6328MB 0.0000MB
825000 1645.8828MB 0.2500MB
850000 1645.8828MB 0.0000MB
875000 1646.1328MB 0.2500MB
900000 1646.1328MB 0.0000MB
925000 1646.3828MB 0.2500MB
950000 1646.3828MB 0.0000MB
975000 1646.6328MB 0.2500MB
Total diff: 4.5039MB (after 4 warmup iterations)
so I'm still not sure why we get this.
As you can see, I skip the first few iterations where memory isn't stable yet, since the actual diff is much larger if we count all iterations.
What do you think?
@stas00 my 2 cents from having looked at a LOT of memory leaks over the years, especially in Python: a 0.3% memory increase over that many iterations of something is difficult to call a leak with any certainty.
Also, just looking at RSS makes it hard to analyze leaks. RSS can stay near constant while you are leaking. RSS is paged-in memory; if you have a big leak, your RSS might not increase much (leaked memory tends not to get used again, so it often gets paged out), while your virtual page allocation could be going through the roof...
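e.g. watching both numbers side by side instead of RSS alone (a minimal psutil sketch):
import os
import psutil
proc = psutil.Process(os.getpid())
def mem_snapshot():
    info = proc.memory_info()
    # rss = resident pages; vms = total virtual address space of the process
    return f"RSS {info.rss / 2**20:10.2f}MB  VMS {info.vms / 2**20:10.2f}MB"
print(mem_snapshot())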
Yes, that's true, but unless the leak is big, I have yet to find another measurement tool.
To prove your point, here is a very simple IO-in-a-loop program that also just reads the same long line over and over:
$ cat mmap-no-leak-debug.py
import gc
import mmap
import os
import psutil
import sys
proc = psutil.Process(os.getpid())
PATH = "./tmp.txt"
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20
# create a large data file with a few long lines
if not os.path.exists(PATH):
    with open(PATH, "w") as fh:
        s = 'a' * 2**27 + "\n"  # 128MB
        # write ~2GB file
        for i in range(16):
            fh.write(s)
print(f"{'idx':>4} {'RSS':>10} {'Δ RSS':>12} {'Δ accumulated':>10}")
total_read = 0
content = ''
mem_after = mem_before_acc = mem_after_acc = mem_before = proc.memory_info().rss / 2**20
print(f"{0:4d} {mem_after:10.2f}MB {mem_after - 0:10.2f}MB {0:10.2f}MB")
mmap_mode = True if "--mmap" in sys.argv else False
with open(PATH, "r") as fh:
if mmap_mode:
mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
idx = 0
while True:
idx += 1
mem_before = mem_read()
line = mm.readline() if mmap_mode else fh.readline()
if not line:
break
#total_read += len(line)
if "--accumulate" in sys.argv:
mem_before_acc = mem_read()
content += str(line)
mem_after_acc = mem_read()
mem_after = mem_read()
print(f"{idx:4d} {mem_after:10.2f}MB {mem_after - mem_before:10.2f}MB {mem_after_acc - mem_before_acc:10.2f}MB")
It has some other instrumentation to do mmap and accumulate data, but let's ignore that for now.
Here it is running in a simple non-mmap IO:
$ python mmap-no-leak-debug.py
idx RSS Δ RSS Δ accumulated
0 12.43MB 12.43MB 0.00MB
1 269.72MB 257.29MB 0.00MB
2 269.73MB 0.02MB 0.00MB
3 269.73MB 0.00MB 0.00MB
4 269.74MB 0.01MB 0.00MB
5 269.74MB 0.00MB 0.00MB
6 269.75MB 0.01MB 0.00MB
7 269.75MB 0.00MB 0.00MB
8 269.76MB 0.01MB 0.00MB
9 269.76MB 0.00MB 0.00MB
10 269.77MB 0.01MB 0.00MB
11 269.77MB 0.00MB 0.00MB
12 269.77MB 0.00MB 0.00MB
13 269.77MB 0.00MB 0.00MB
14 269.77MB 0.00MB 0.00MB
15 269.77MB 0.00MB 0.00MB
16 146.02MB -123.75MB 0.00MB
As you can see, even this super-simplistic program that just performs readline() slightly increases its RSS over iterations.
If you have a better measurement tool than RSS, I'm all ears.
@stas00 if you aren't using memory maps, you should be able to clearly see the increase in the virtual memory of the process as well. Even then, it could still be challenging to determine if it's a leak vs fragmentation due to problematic allocation patterns (not uncommon with Python). Using a better memory allocator like tcmalloc via LD_PRELOAD hooks could reduce the impact of fragmentation across both Python and C libs. Not sure that plays nicely with any allocator that arrow might use itself, though.