
With dataloader RSS memory consumed by HF datasets monotonically increases

Open apsdehal opened this issue 3 years ago • 44 comments

Describe the bug

When an HF dataset is used in conjunction with the PyTorch DataLoader, the RSS memory of the process keeps increasing when it should stay constant.

Steps to reproduce the bug

Run and observe the output of this snippet which logs RSS memory.

import psutil
import os
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_TRIES = 10

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader):
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(count, idx, mem_after - mem_before)
    count += 1

Expected results

Memory should not increase after initial setup and loading of the dataset

Actual results

Memory continuously increases as can be seen in the log.

Environment info

  • datasets version: 2.3.2
  • Platform: Linux-4.19.0-21-cloud-amd64-x86_64-with-glibc2.10
  • Python version: 3.8.13
  • PyArrow version: 7.0.0

apsdehal avatar Aug 24 '22 08:08 apsdehal

Are you sure there is a leak? How can I see it? You shared the script but not the output which you believe should indicate a leak.

I modified your reproduction script to print only once per try, as your original was printing too much info. Also, you absolutely must add gc.collect() when doing any memory measurements, since Python's GC runs on its own schedule and you might otherwise be measuring the wrong thing. This gives us:

import psutil
import os
import gc
from transformers import BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_TRIES = 100

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def transform(x):
    x.update(tokenizer(x["text"], return_tensors="pt", max_length=64, padding="max_length", truncation=True))
    x.pop("text")
    x.pop("label")
    return x
dataset = load_dataset("imdb", split="train")
dataset.set_transform(transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

count = 0
while count < NUM_TRIES:
    for idx, batch in enumerate(train_loader): pass
    gc.collect()
    mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    print(count, mem_after - mem_before)
    count += 1

Now running it:

$ python dl-leak.py 
Reusing dataset imdb (/home/stas/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
0 4.43359375
1 4.4453125
2 4.44921875
3 4.44921875
4 4.4609375
5 4.46484375
6 4.46484375
7 4.46484375
8 4.46484375
9 4.46484375
10 4.46484375
11 4.46484375
12 4.46484375
13 4.46484375
14 4.46484375
15 4.46484375
16 4.46484375

It's normal for there to be a small growth in memory usage at the beginning, but after about 5 cycles it stays steady.

stas00 avatar Aug 30 '22 19:08 stas00

Unless of course you're referring to the memory growth during the first try. Is that what you mean? And since your ds is small it's hard to see the growth - could it just be that some records are longer and it needs to allocate more memory for those?

Though while experimenting with this I observed a peculiar thing: if I concatenate 2 datasets, I don't see any growth at all. But that's probably because the program allocated additional peak RSS memory for the concatenation and is then re-using that memory.

I basically tried to see what happens if I make the dataset much longer; I'd expect to see no memory growth once the 780 records of the imdb ds have been processed once.

stas00 avatar Aug 30 '22 21:08 stas00

It is hard to say if it is directly reproducible in this setup. Maybe it is specific to the images stored in the CM4 case which cause a memory leak. I am still running your script and seeing if I can reproduce that particular leak in this case.

apsdehal avatar Aug 30 '22 21:08 apsdehal

I was able to reproduce the leak with:

import psutil
import os
import gc
from datasets import load_from_disk
import time

DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"

dataset = load_from_disk(DATASET_PATH)

# truncate to a tiny dataset
dataset = dataset.select(range(1000))

print(f"dataset: {len(dataset)} records")

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, rec in enumerate(dataset):
    if idx % 100 == 0:
        gc.collect()
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

You need to adjust the DATASET_PATH value.

which you get from

gsutil -m cp   "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset.arrow"   "gs://hf-science-m4/cm4/cm4-10000-v0.1/dataset_info.json"   "gs://hf-science-m4/cm4/cm4-10000-v0.1/state.json"   .

(I assume the hf folks have the perms) - it's a smallish dataset (10k)

then you run:

$ python ds.py
dataset: 1000 records
   0       1.0156MB
 100     126.3906MB
 200     142.8906MB
 300     168.5586MB
 400     218.3867MB
 500     230.7070MB
 600     238.9570MB
 700     263.3789MB
 800     288.1289MB
 900     300.5039MB

you should be able to see the leak

stas00 avatar Sep 07 '22 05:09 stas00

This issue has nothing to do with PIL's decoder. I removed it and the problem is still there.

I then traced this leak to this single call: pa_table.to_pydict() here:

https://github.com/huggingface/datasets/blob/08a7b389cdd6fb49264a72aa8ccfc49a233494b6/src/datasets/formatting/formatting.py#L138-L140

I can make it leak much faster by modifying that code to repeat pa_table.to_pydict() many times in a row. It shouldn't have that impact:

class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]):
    def extract_row(self, pa_table: pa.Table) -> dict:
        x = [pa_table.to_pydict() for x in range(200)]
        return _unnest(pa_table.to_pydict())

@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?

Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.

The problem already happens with pyarrow==6.0.0 or later (minimum for current datasets)

I'm also trying to dig in with objgraph to see if there are any circular references that prevent objects from being freed, but no luck there so far. And I'm pretty sure to_pydict is not Python code, so the problem likely happens somewhere outside of Python's GC.
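
For reference, here is a rough sketch of the objgraph probing I mean (objgraph is pip-installable; the path is the one from the repro above):

import gc
import objgraph  # pip install objgraph
from datasets import load_from_disk

DATASET_PATH = "/hf/m4-master/data/cm4/cm4-10000-v0.1"  # same path as in the repro above

dataset = load_from_disk(DATASET_PATH).select(range(1000))

objgraph.show_growth(limit=10)   # baseline snapshot of per-type python object counts

for rec in dataset:              # drain the dataset once
    pass

gc.collect()
objgraph.show_growth(limit=10)   # a python-side leak would show up here as a type that keeps growing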

stas00 avatar Sep 09 '22 04:09 stas00

This appears to be the same issue I think: https://github.com/huggingface/datasets/issues/4528 I dug into the repro code there and it's the same behavior with the same leak, but it's a pure nlp dataset and thus much faster to work with.

stas00 avatar Sep 09 '22 04:09 stas00

I went all the way back to pyarrow==1.0.0 and datasets==1.12.0 and the problem is still there. How is it even possible that it wasn't noticed all this time?

Could it be that the leak is in some 3rd-party component pyarrow relies on? While downgrading I only downgraded the above 2 packages.

stas00 avatar Sep 09 '22 05:09 stas00

Also found this warning

    Be careful: if you don't pass the ArrowArray struct to a consumer,
    array memory will leak.  This is a low-level function intended for
    expert users.

see: https://github.com/apache/arrow/blob/99b57e84277f24e8ec1ddadbb11ef8b4f43c8c89/python/pyarrow/table.pxi#L2515-L2517

perhaps something triggers this condition?

I have no idea if it's related - this is just something that came up during my research.

stas00 avatar Sep 09 '22 05:09 stas00

Does it crash with OOM at some point? If it doesn't, it isn't a leak, just aggressive caching or a custom allocator that doesn't like to give memory back (not uncommon). #4528 looks like it hits a steady state.

I believe the underlying arrow libs use a custom C allocator. Some of those are designed not to give memory back to the OS, but to keep heap memory for themselves to re-use (hitting up the OS involves more expensive mutex locks, contention, etc). The greedy behaviour can be undesirable though. There are likely flags to change the allocator behaviour, and one could likely build without any custom allocators (or use a different one).
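
For instance, something along these lines should let one check and swap the pool or tweak jemalloc's behaviour (untested sketch, no idea if it changes anything here):

import pyarrow as pa

print(pa.default_memory_pool().backend_name)   # which allocator is currently in use

# try the plain system allocator instead of jemalloc/mimalloc
pa.set_memory_pool(pa.system_memory_pool())

# or keep jemalloc but tell it to give freed memory back to the OS immediately
pa.jemalloc_set_decay_ms(0)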

rwightman avatar Sep 09 '22 06:09 rwightman

Does it crash with OOM at some point?

In the original setup where we noticed this problem, it was indeed ending in an OOM

SaulLu avatar Sep 09 '22 08:09 SaulLu

https://github.com/huggingface/datasets/issues/4528 looks like it hits a steady state.

@rwightman in the plot I shared, the steady state comes from the time.sleep(100) I added at the end of the script, to showcase that even the garbage collector couldn't free that allocated memory.

NouamaneTazi avatar Sep 09 '22 08:09 NouamaneTazi

Could this be related to this discussion about a potential memory leak in pyarrow: https://issues.apache.org/jira/browse/ARROW-11007 ?

(Note: I've tried import pyarrow; pyarrow.jemalloc_set_decay_ms(0) and the memory leak is still happening on your toy example)

SaulLu avatar Sep 09 '22 12:09 SaulLu

@lhoestq - do you know what might be happening inside pa_table.to_pydict(), as this is in the pyarrow domain. Perhaps you know someone to tag from that project?

to_pydict calls to_pylist on each column (i.e. on each PyArrow Array). Then it iterates on the array and calls as_py on each element. The as_py implementation depends on the data type. For strings I think it simply gets the buffer that contains the binary string data that is defined in C++
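
In pure pyarrow terms it roughly boils down to something like this (a simplified sketch of the call chain, not the actual C++ implementation):

import pyarrow as pa

pa_table = pa.table({"text": ["foo", "bar"], "label": [0, 1]})

# what to_pydict does, more or less:
pydict = {
    col_name: [scalar.as_py() for scalar in pa_table[col_name]]  # to_pylist() per column
    for col_name in pa_table.column_names
}
assert pydict == pa_table.to_pydict()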

The Arrow team is pretty responsive at [email protected] if it can help

Probably next need to remove datasets from the equation and make a reproducible case with just pyarrow directly.

That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?

lhoestq avatar Sep 09 '22 13:09 lhoestq

That would be ideal indeed. Would be happy to help on this, can you give me access to the bucket so I can try with your data ?

I added you to the bucket @lhoestq

VictorSanh avatar Sep 09 '22 13:09 VictorSanh

It looks like an issue with memory mapping:

  • the amount of memory used in the end corresponds to the size of the dataset
  • setting keep_in_memory=True in load_from_disk loads the dataset in RAM, and doesn't cause any memory leak

lhoestq avatar Sep 09 '22 14:09 lhoestq

Here is a code to reproduce this issue using only PyArrow and a dummy arrow file:

import psutil
import os
import gc
import pyarrow as pa
import time

ARROW_PATH = "tmp.arrow"

if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000)  # ~200MB
    table = pa.table({"a": arr})

    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()


def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table


table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

prints

   0       0.2500MB
 100      19.8008MB
 200      39.3320MB
 300      58.8633MB
 400      78.3945MB
 500      97.9258MB
 600     117.4570MB
 700     136.9883MB
 800     156.5195MB
 900     176.0508MB

Note that this example simply iterates over the pyarrow.lib.BinaryScalar objects in the array. Running .as_py() is not needed to experience the memory issue.

lhoestq avatar Sep 09 '22 14:09 lhoestq

@lhoestq that does indeed increase in memory, but if you iterate over the array again after the first time, or re-open and re-map the same file (repeat table = memory_mapped_arrow_table_from_file(ARROW_PATH)) before re-iterating, it doesn't move past 195MB... it would appear another step is needed to continue consuming memory past that... hmmm

Are the pa_tables held on to anywhere after they are iterated in the real code?

In my hack, if you do a bunch of copy & paste and then change the arr name for each iteration:

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]

for idx, x in enumerate(arr):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr1 = table[0]

for idx, x in enumerate(arr1):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr2 = table[0]

for idx, x in enumerate(arr2):
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")

It leaks; if all the arr variables have the same name (so the previous one gets cleaned up), it does not and goes back to 0. Is there anything that could be holding onto a reference to an intermediary equivalent of arr in the real use case?

rwightman avatar Sep 09 '22 15:09 rwightman

Yes, we have already established here https://github.com/huggingface/datasets/issues/4883#issuecomment-1232063891 that when one iterates over the whole dataset multiple times, it consumes a bit more memory in the next few repetitions and then remains steady.

Which means that when a new iterator is created over the same dataset, all the memory from the previous iterator is re-used.

So the leak happens primarily when the iterator is "drained" the first time, which tells me that either a circular reference is created somewhere that only gets released when the iterator is destroyed, or there is some global variable that keeps piling up memory and doesn't release it in time.

Also I noticed some __del__ methods, which don't guarantee objects are destroyed automatically, and there is usually a warning against relying on them: https://stackoverflow.com/a/1481512/9201239

There are also some weakrefs in the code which too may lead to leaks or weird problems at times.

stas00 avatar Sep 09 '22 15:09 stas00

@stas00 my point was, I'm not convinced @lhoestq's last example illustrates the leak, but rather the differences between memory-mapped and in-memory usage patterns. If you destroy arr, the memory-map impl goes back to 0 each iteration. The amount of memory that 'looks' like it is leaked in the first pass differs quite a bit between memory-mapped vs in-memory, but the underlying issue is likely a circular reference, or reference(s) that were not cleaned up, which would impact either case but be much more visible with mmap.

rwightman avatar Sep 09 '22 16:09 rwightman

Thank you for clarifying, Ross.

I think we agree that it's almost certain that the datasets iterator traps some inner variable that prevents object freeing, since if we create the iterator multiple times (and drain it), after a few runs no new memory is allocated. We could try to dig in more with objgraph - my main concern is whether the problem happens somewhere outside of Python (i.e. in the pyarrow C++ implementation), in which case it'd be much more difficult to trace.

I wish there was a way on Linux to tell the program to free no-longer-used memory at will.
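
The closest thing I'm aware of is glibc's malloc_trim, which asks the allocator to return free heap pages to the OS; a minimal sketch (glibc-only, and it does nothing for mmapped pages or Arrow's own pools):

import ctypes

# glibc-only: ask malloc to return any free heap pages to the OS right now.
# This does not affect memory-mapped files or pyarrow's own memory pools.
libc = ctypes.CDLL("libc.so.6")
libc.malloc_trim(0)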

stas00 avatar Sep 09 '22 16:09 stas00

FWIW, I revisited some code I had in the works to use HF datasets w/ timm train & val scripts. There is no leak there across multiple epochs. It uses the defaults.

It's worth noting that with imagenet keep_in_memory=True isn't even an option because the train arrow file is ~140GB and my local memory is less. The virtual address space reflects mmap (> 150GB) and doesn't increase over epochs that I noticed. I have some perf issues to bring up wrt to the current setup, but that's a separate and lower prio discussion to have elsewhere...

rwightman avatar Sep 09 '22 23:09 rwightman

Notes

After reading many issues and trying many things, here is a summary of what I've learned.

I'm now using @lhoestq repro case as it's pyarrow-isolated: https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985

1. pyarrow memory backends

It has 3 backends; I tried them all with the same results:

pa.set_memory_pool(pa.jemalloc_memory_pool())
pa.set_memory_pool(pa.mimalloc_memory_pool())
pa.set_memory_pool(pa.system_memory_pool())

2. quick release

The jemalloc backend supports quick release

pa.jemalloc_set_decay_ms(0)

it doesn't make any difference in this case

3. actual memory allocations

this is a useful tracer for PA memory allocators

pa.log_memory_allocations(enable=True)

it nicely reports memory allocations and releases when the arrow file is created the first time.

But when we then try to do enumerate(arr), this logger reports 0 allocations.

This summary also reports no allocations when the script is run a second time (post file creation):

mem_pool = pa.default_memory_pool()
print(f"PyArrow mem pool info: {mem_pool.backend_name} backend, {mem_pool.bytes_allocated()} allocated, "
              f"{mem_pool.max_memory()} max allocated, ")

print(f"PyArrow total allocated bytes: {pa.total_allocated_bytes()}")

However, it's easy to see by using tracemalloc, which only measures Python's memory allocations, that it's PA that "leaks", since tracemalloc shows the Python-side memory staying fixed

(this is bolted on top of the original repro script)

import tracemalloc
tracemalloc.start()

[...]
for idx, x in enumerate(arr):
    if idx % 10 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
        mem_use = pa.total_allocated_bytes() - start_use
        mem_peak = pool.max_memory() - start_peak_use

        second_size, second_peak = tracemalloc.get_traced_memory()
        mem_diff = (second_size - first_size) / 2**20
        mem_peak_diff = (second_peak - first_peak) / 2**20

        # pa.jemalloc_memory_pool().release_unused()
        # pa.mimalloc_memory_pool().release_unused()
        # pa.system_memory_pool().release_unused()

        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {memory_mapped_stream.size()/2**20:4.4}MB {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")

gives:

   0       5.4258MB       0.0110       0.0201 195.3MB  0.0MB  0.0MB
  10      25.3672MB       0.0112       0.0202 195.3MB  0.0MB  0.0MB
  20      45.9336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  30      62.4336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  40      83.0586MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  50     103.6836MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  60     124.3086MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  70     140.8086MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  80     161.4336MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB
  90     182.0586MB       0.0112       0.0203 195.3MB  0.0MB  0.0MB

the 3rd and 4th columns are tracemalloc's report.

the 5th column is the size of mmaped stream - fixed.

the last 2 are the PA's malloc reports - you can see it's totally fixed and 0.

So what gives? PA's memory allocator says nothing was allocated and we can see python doesn't allocate any memory either.

As someone suggested in one of the PA issues, IPC/gRPC could be the issue. Any suggestions on how to debug this one?

The main issue is that one can't step through with a Python debugger, as arr is an opaque C++ object bound to Python.

Please see the next comment for a possible answer.

ref-count

I also traced reference counts using either sys.getrefcount(x) or len(gc.get_referrers(x)), and they all stay fixed,

so it's not the Python objects that are piling up.
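
(roughly like this, bolted into the iteration loop of the repro above, where arr is the memory-mapped array:)

import gc
import sys

for idx, x in enumerate(arr):
    if idx % 100 == 0:
        # both numbers stay constant from iteration to iteration
        print(idx, sys.getrefcount(x), len(gc.get_referrers(x)))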

Important related discussions

https://issues.apache.org/jira/browse/ARROW-11007 - looks very similar to our issue in particular this part of the report: https://issues.apache.org/jira/browse/ARROW-11007?focusedCommentId=17279642&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-17279642

stas00 avatar Sep 10 '22 02:09 stas00

There is no leak, just badly communicated Linux RSS memory usage stats

Next, lets revisit @rwightman's suggestion that there is actually no leak.

After all, we are using mmap, which will map as much of the file into RAM as it can and page out when there is no free memory; i.e. mmap is only fast if you have a lot of CPU RAM.

So let's do it:

Memory mapping OOM test

We first quickly start a cgroups-controlled shell which will instantly kill any program that consumes more than 1GB of memory:

$ systemd-run --user --scope -p MemoryHigh=1G -p MemoryMax=1G -p MemorySwapMax=1G --setenv="MEMLIMIT=1GB" bash

Let's check that it indeed does so. Let's change @lhoestq's script to allocate a 10GB arrow file:

$ python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 5000)'
Killed

Oops, that didn't work: we tried to allocate 10GB when only 1GB is allowed, so the process got killed. This is what we want!

Let's do a sanity check - can we allocate 0.1GB?

python -c 'import pyarrow as pa; pa.array([b"a" * (2000 * 1024)] * 50)'

Yes. So the limited shell does the right thing: it lets us allocate < 1GB of RSS RAM.

Next let's go back to @lhoestq's script but with 10GB arrow file.

we change his repro script https://github.com/huggingface/datasets/issues/4883#issuecomment-1242034985 to 50x larger file

    arr = pa.array([b"a" * (2000 * 1024)] * 5000)  # ~10000MB

we first have to run it once in a normal unlimited shell so that the file gets created without us getting killed (as creating it allocates 10GB)

let's run the script now in the 1GB-limited shell while running a monitor:

$ htop -F python -s M_RESIDENT -u `whoami`

so we have 2 sources of RSS info just in case.

$ python pyar.py
   0       4.3516MB       0.0103       0.0194 9.766e+03MB  0.0MB  0.0MB
  10      24.3008MB       0.0104       0.0195 9.766e+03MB  0.0MB  0.0MB
[...]
4980    9730.3672MB       0.0108       0.0199 9.766e+03MB  0.0MB  0.0MB
4990    9750.9922MB       0.0108       0.0199 9.766e+03MB  0.0MB  0.0MB
PyArrow mem pool info: jemalloc backend, 0 allocated, 0 max allocated,
PyArrow total allocated bytes: 0

But wait, it reported 10GB RSS both in htop and in our log!

So that means it never actually allocated 10GB, otherwise it'd have been killed.

Which tells us that there is no leak whatsoever and this is just a really tricky situation where mmapped memory is reported as part of RSS, which it arguably shouldn't be. The downside is that we now have no obvious way to measure real memory usage.
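
One partial workaround (a sketch, not something I've wired into the attached script): recent Linux kernels split RSS into anonymous vs file-backed pages in /proc/<pid>/status, so a genuine heap leak should show up in RssAnon while the mmapped arrow file only inflates RssFile:

def rss_breakdown_mb():
    # RssAnon = anonymous (heap) pages, RssFile = file-backed (mmapped) pages
    # these fields exist in /proc/<pid>/status on Linux >= 4.5
    out = {}
    with open("/proc/self/status") as fh:
        for line in fh:
            if line.startswith(("VmRSS", "RssAnon", "RssFile")):
                key, value = line.split(":")
                out[key] = int(value.strip().split()[0]) / 1024  # kB -> MB
    return out

print(rss_breakdown_mb())  # e.g. {'VmRSS': ..., 'RssAnon': ..., 'RssFile': ...}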

I also attached the script with all the different things I have tried in it, so it should be easy to turn them on/off if you want to reproduce any of my findings.

pyar.txt

just rename it to pyar.py as gh doesn't allow attaching scripts...

(I have to remember to exit that special mem-limited shell or else I won't be able to do anything serious there.)

stas00 avatar Sep 10 '22 02:09 stas00

The original leak in the multi-modal code is very likely something else. But of course now it'd be very difficult to trace it using mmap.

I think to debug we have to set keep_in_memory=True in load_from_disk to load the small dataset in RAM, so there is no misleading mmap reporting component, and then continue searching for another source of the leak.

stas00 avatar Sep 10 '22 02:09 stas00

To add to what @stas00 found, I'm gonna leave some links to where I believe the confusion came from in pyarrow's APIs, for future reference:

Arrow can directly reference the data mapped from disk and avoid having to allocate its own memory.

And where their example shows 0 RSS memory allocation, the way we used to measure RSS shows 39.6719MB allocated. Here's the script to reproduce:

import psutil
import os
import gc
import pyarrow as pa
import time
import sys

# gc.set_debug(gc.DEBUG_LEAK)
# gc.set_threshold(0,0,0)

#pa.set_memory_pool(pa.mimalloc_memory_pool())
#pa.set_memory_pool(pa.system_memory_pool())

import tracemalloc

#pa.jemalloc_set_decay_ms(0)
# pa.log_memory_allocations(enable=True)

BATCH_SIZE = 10000
NUM_BATCHES = 1000
schema = pa.schema([pa.field('nums', pa.int32())])
with pa.OSFile('bigfile.arrow', 'wb') as sink:
   with pa.ipc.new_file(sink, schema) as writer:
      for row in range(NUM_BATCHES):
            batch = pa.record_batch([pa.array(range(BATCH_SIZE), type=pa.int32())], schema)
            writer.write(batch)

start_use = pa.total_allocated_bytes()
pool = pa.default_memory_pool()
start_peak_use = pool.max_memory()
tracemalloc.start()
first_size, first_peak = tracemalloc.get_traced_memory()
mem_before = psutil.Process(os.getpid()).memory_info().rss / 2**20


# with pa.OSFile('bigfile.arrow', 'rb') as source:
#    loaded_array = pa.ipc.open_file(source).read_all()

with pa.memory_map('bigfile.arrow', 'rb') as source:
   loaded_array = pa.ipc.open_file(source).read_all()


print("LEN:", len(loaded_array))
print("RSS: {}MB".format(pa.total_allocated_bytes() >> 20))

gc.collect()
time.sleep(0.1)
mem_after = psutil.Process(os.getpid()).memory_info().rss / 2**20
mem_use = pa.total_allocated_bytes() - start_use
mem_peak = pool.max_memory() - start_peak_use
second_size, second_peak = tracemalloc.get_traced_memory()
mem_diff = (second_size - first_size) / 2**20
mem_peak_diff = (second_peak - first_peak) / 2**20

idx = 0
print(f"{idx:4d} {mem_after - mem_before:12.4f}MB {mem_diff:12.4f} {mem_peak_diff:12.4f} {mem_use/2**20:4.4}MB {mem_peak/2**20:4.4}MB")

gives:


LEN: 10000000
RSS: 0MB
   0      39.6719MB       0.0132       0.0529  0.0MB  0.0MB

Which again just shows that we measure RSS incorrectly in the case of mmapped memory.

NouamaneTazi avatar Sep 12 '22 08:09 NouamaneTazi

@lhoestq, I have been working on a detailed article that shows that mmap doesn't leak; it's mostly done and I will share it when it's ready.

The issue is that we still need to be able to debug memory leaks by turning MMAP off.

But once I tried to show the user that using load_dataset(..., keep_in_memory=True) is the way to debug an actual memory leak, guess what I discovered? A potential actual leak.

Here is the repro:

$ cat ds-mmap.py
from datasets import load_dataset
import gc
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

dataset = load_dataset("wmt19", 'cs-en', keep_in_memory=True, streaming=False)['train']

print(f"{'idx':>6} {'RSS':>10} {'Δ RSS':>15}")
step = 20000
for i in range(0, 10*step, step):
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:6d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB ")

$ python ds-io.py
Reusing dataset wmt19 (/home/stas/.cache/huggingface/datasets/wmt19/cs-en/1.0.0/c3db1bf4240362ed1ef4673b354f468d70aac66d4e67d45f536d493a0840f0d3)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.66it/s]
   idx        RSS           Δ RSS
     0    1398.4609MB       3.5195MB
 20000    1398.5742MB       0.1133MB
 40000    1398.6016MB       0.0273MB
 60000    1398.6016MB       0.0000MB
 80000    1398.6016MB       0.0000MB
100000    1398.6328MB       0.0312MB
120000    1398.6953MB       0.0625MB
140000    1398.6953MB       0.0000MB
160000    1398.7500MB       0.0547MB
180000    1398.7500MB       0.0000MB

stas00 avatar Sep 15 '22 21:09 stas00

As I suggested on Slack, perhaps it was due to variation in dataset record lengths, so with your help I wrote another repro with synthetic records that are all identical - which should remove my hypothesis from the equation, and we should expect 0 incremental growth as we iterate over the dataset. But alas this is not the case. There is a tiny but definite leak-like behavior.

Here is the new repro:

$ cat ds-synthetic-no-mmap.py
from datasets import load_from_disk, Dataset
import gc
import sys
import os
import psutil

proc = psutil.Process(os.getpid())
def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

DS_PATH = "synthetic-ds"
if not os.path.exists(DS_PATH):
    records = 1_000_000
    print("Creating a synthetic dataset")
    row = dict(foo=[dict(a='a'*500, b='b'*1000)])
    ds = Dataset.from_dict({k: [v] * records for k, v in row.items()})
    ds.save_to_disk(DS_PATH)
    print("Done. Please restart the program")
    sys.exit()

dataset = load_from_disk(DS_PATH, keep_in_memory=True)
print(f"Dataset len={len(dataset)}")

print(f"{'idx':>8} {'RSS':>10} {'Δ RSS':>15}")
mem_start = 0
step = 25_000
warmup_iterations = 4
for idx, i in enumerate(range(0, len(dataset), step)):
    if idx == warmup_iterations: # skip the first few iterations while things get set up
        mem_start = mem_read()
    mem_before = mem_read()
    _ = dataset[i:i+step]
    mem_after = mem_read()
    print(f"{i:8d} {mem_after:12.4f}MB {mem_after - mem_before:12.4f}MB")
mem_end = mem_read()

print(f"Total diff: {mem_end - mem_start:12.4f}MB (after {warmup_iterations} warmup iterations)")

and the run:

$ python ds-synthetic-no-mmap.py
Dataset len=1000000
     idx        RSS           Δ RSS
       0    1601.9258MB      47.9688MB
   25000    1641.6289MB      39.7031MB
   50000    1641.8594MB       0.2305MB
   75000    1642.1289MB       0.2695MB
  100000    1642.1289MB       0.0000MB
  125000    1642.3789MB       0.2500MB
  150000    1642.3789MB       0.0000MB
  175000    1642.6289MB       0.2500MB
  200000    1642.6289MB       0.0000MB
  225000    1642.8789MB       0.2500MB
  250000    1642.8828MB       0.0039MB
  275000    1643.1328MB       0.2500MB
  300000    1643.1328MB       0.0000MB
  325000    1643.3828MB       0.2500MB
  350000    1643.3828MB       0.0000MB
  375000    1643.6328MB       0.2500MB
  400000    1643.6328MB       0.0000MB
  425000    1643.8828MB       0.2500MB
  450000    1643.8828MB       0.0000MB
  475000    1644.1328MB       0.2500MB
  500000    1644.1328MB       0.0000MB
  525000    1644.3828MB       0.2500MB
  550000    1644.3828MB       0.0000MB
  575000    1644.6328MB       0.2500MB
  600000    1644.6328MB       0.0000MB
  625000    1644.8828MB       0.2500MB
  650000    1644.8828MB       0.0000MB
  675000    1645.1328MB       0.2500MB
  700000    1645.1328MB       0.0000MB
  725000    1645.3828MB       0.2500MB
  750000    1645.3828MB       0.0000MB
  775000    1645.6328MB       0.2500MB
  800000    1645.6328MB       0.0000MB
  825000    1645.8828MB       0.2500MB
  850000    1645.8828MB       0.0000MB
  875000    1646.1328MB       0.2500MB
  900000    1646.1328MB       0.0000MB
  925000    1646.3828MB       0.2500MB
  950000    1646.3828MB       0.0000MB
  975000    1646.6328MB       0.2500MB
Total diff:       4.5039MB (after 4 warmup iterations)

so I'm still not sure why we get this.

As you can see, I started skipping the first few iterations where memory isn't stable yet; the actual diff is much larger if we count all iterations.

What do you think?

stas00 avatar Sep 15 '22 21:09 stas00

@stas00 my 2 cents from having looked at a LOT of memory leaks over the years, esp in Python: a .3% memory increase over that many iterations is difficult to call a leak with any certainty.

Also, just looking at RSS makes it hard to analyze leaks. RSS can stay near constant while you are leaking. RSS is paged-in memory; if you have a big leak, your RSS might not increase much (leaked memory tends not to get used again, so it often gets paged out) while your virtual address space could be going through the roof...
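
e.g. it's easy to log both with psutil (sketch); a leak that isn't visible in RSS may still show up in VMS:

import os
import psutil

proc = psutil.Process(os.getpid())
mi = proc.memory_info()
# rss = resident (paged-in) memory, vms = total virtual address space
print(f"RSS: {mi.rss / 2**20:.1f}MB  VMS: {mi.vms / 2**20:.1f}MB")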

rwightman avatar Sep 15 '22 22:09 rwightman

Yes, that's true, but unless the leak is big, I have yet to find a better measurement tool.

To prove your point, here is a very simple IO-in-a-loop program that reads the same kind of line over and over:

$ cat mmap-no-leak-debug.py
import gc
import mmap
import os
import psutil
import sys

proc = psutil.Process(os.getpid())

PATH = "./tmp.txt"

def mem_read():
    gc.collect()
    return proc.memory_info().rss / 2**20

# create a large data file with a few long lines
if not os.path.exists(PATH):
    with open(PATH, "w") as fh:
        s = 'a'* 2**27 + "\n" # 128MB
        # write ~2GB file
        for i in range(16):
            fh.write(s)

print(f"{'idx':>4} {'RSS':>10}   {'Δ RSS':>12}   {'Δ accumulated':>10}")

total_read = 0
content = ''
mem_after = mem_before_acc = mem_after_acc = mem_before = proc.memory_info().rss / 2**20
print(f"{0:4d} {mem_after:10.2f}MB {mem_after - 0:10.2f}MB {0:10.2f}MB")

mmap_mode = True if "--mmap" in sys.argv else False

with open(PATH, "r") as fh:

    if mmap_mode:
        mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)

    idx = 0
    while True:
        idx += 1
        mem_before = mem_read()
        line = mm.readline() if mmap_mode else fh.readline()
        if not line:
            break

        #total_read += len(line)

        if "--accumulate" in sys.argv:
            mem_before_acc = mem_read()
            content += str(line)
            mem_after_acc = mem_read()

        mem_after = mem_read()

        print(f"{idx:4d} {mem_after:10.2f}MB {mem_after - mem_before:10.2f}MB {mem_after_acc - mem_before_acc:10.2f}MB")

It has some additional instrumentation for mmap mode and for accumulating the data, but let's ignore that for now.

Here it is running in a simple non-mmap IO:

$ python mmap-no-leak-debug.py
 idx        RSS          Δ RSS   Δ accumulated
   0      12.43MB      12.43MB       0.00MB
   1     269.72MB     257.29MB       0.00MB
   2     269.73MB       0.02MB       0.00MB
   3     269.73MB       0.00MB       0.00MB
   4     269.74MB       0.01MB       0.00MB
   5     269.74MB       0.00MB       0.00MB
   6     269.75MB       0.01MB       0.00MB
   7     269.75MB       0.00MB       0.00MB
   8     269.76MB       0.01MB       0.00MB
   9     269.76MB       0.00MB       0.00MB
  10     269.77MB       0.01MB       0.00MB
  11     269.77MB       0.00MB       0.00MB
  12     269.77MB       0.00MB       0.00MB
  13     269.77MB       0.00MB       0.00MB
  14     269.77MB       0.00MB       0.00MB
  15     269.77MB       0.00MB       0.00MB
  16     146.02MB    -123.75MB       0.00MB

As you can see, even this super-simplistic program that just performs readline() slightly increases its RSS over iterations.

If you have a better measurement tool than RSS, I'm all ears.

stas00 avatar Sep 15 '22 23:09 stas00

@stas00 if you aren't using memory maps, you should be able to clearly see the increase in the virtual memory of the process as well. Even then, it could still be challenging to determine if it's a leak vs fragmentation due to problematic allocation patterns (not uncommon with Python). Using a better memory allocator like tcmalloc via LD_PRELOAD hooks could reduce the impact of fragmentation across both Python and C libs. Not sure that plays nicely with any allocator that arrow might use itself, though.

rwightman avatar Sep 15 '22 23:09 rwightman