Easy-Transformer
[Bug Report] GPU Memory is leaked when HookedTransformer goes out of scope.
I was iterating some operations over the Pythia model suite and realized that each time I loaded a model, the GPU memory usage would increase. Because the memory usage as reported by torch and nvidia-smi is several GB, I'm guessing there are some arrays that are attached to a cache or class rather than to an instance of HookedTransformer. I would be happy to debug when I get a chance. For now, I'm able to just use transformers without TransformerLens. I was originally just using HookedTransformer out of habit.
Here's a minimal reproducing example:
import os
import time
import gc

import transformers
import torch
from transformer_lens import HookedTransformer

model_param_str = ['70m', '160m', '410m']
model_names = {mps: f"EleutherAI/pythia-{mps}-deduped" for mps in model_param_str}

# GPU memory usage is 0 (as checked by nvidia-smi) after each model is loaded and garbage collected.
for param_str, model_name in model_names.items():
    model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)
    model.cuda()

del model
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_summary())
# Give the GPU some time to update its memory usage.
time.sleep(1)
os.system('nvidia-smi')

# GPU memory usage accumulates from past models.
for param_str, model_name in model_names.items():
    model = HookedTransformer.from_pretrained(model_name)

del model
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_summary())
time.sleep(1)
os.system('nvidia-smi')
and the output is:
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 0 B | 1642 MiB | 2608 MiB | 2608 MiB |
| from large pool | 0 B | 1641 MiB | 2600 MiB | 2600 MiB |
| from small pool | 0 B | 6 MiB | 7 MiB | 7 MiB |
|---------------------------------------------------------------------------|
| Active memory | 0 B | 1642 MiB | 2608 MiB | 2608 MiB |
| from large pool | 0 B | 1641 MiB | 2600 MiB | 2600 MiB |
| from small pool | 0 B | 6 MiB | 7 MiB | 7 MiB |
|---------------------------------------------------------------------------|
| Requested memory | 0 B | 1642 MiB | 2602 MiB | 2602 MiB |
| from large pool | 0 B | 1641 MiB | 2594 MiB | 2594 MiB |
| from small pool | 0 B | 6 MiB | 7 MiB | 7 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 0 B | 1664 MiB | 1664 MiB | 1664 MiB |
| from large pool | 0 B | 1656 MiB | 1656 MiB | 1656 MiB |
| from small pool | 0 B | 8 MiB | 8 MiB | 8 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 0 B | 207716 KiB | 2128 MiB | 2128 MiB |
| from large pool | 0 B | 206336 KiB | 2114 MiB | 2114 MiB |
| from small pool | 0 B | 4053 KiB | 13 MiB | 13 MiB |
|---------------------------------------------------------------------------|
| Allocations | 0 | 364 | 642 | 642 |
| from large pool | 0 | 122 | 210 | 210 |
| from small pool | 0 | 242 | 432 | 432 |
|---------------------------------------------------------------------------|
| Active allocs | 0 | 364 | 642 | 642 |
| from large pool | 0 | 122 | 210 | 210 |
| from small pool | 0 | 242 | 432 | 432 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 0 | 53 | 53 | 53 |
| from large pool | 0 | 49 | 49 | 49 |
| from small pool | 0 | 4 | 4 | 4 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 0 | 12 | 159 | 159 |
| from large pool | 0 | 9 | 106 | 106 |
| from small pool | 0 | 5 | 53 | 53 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
Tue May 30 22:15:35 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla V100-PCIE-16GB On | 00000001:00:00.0 Off | Off |
| N/A 28C P0 41W / 250W| 370MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-410m-deduped into HookedTransformer
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 2699 MiB | 6969 MiB | 13621 MiB | 10921 MiB |
| from large pool | 2663 MiB | 6895 MiB | 13477 MiB | 10813 MiB |
| from small pool | 36 MiB | 80 MiB | 144 MiB | 107 MiB |
|---------------------------------------------------------------------------|
| Active memory | 2699 MiB | 6969 MiB | 13621 MiB | 10921 MiB |
| from large pool | 2663 MiB | 6895 MiB | 13477 MiB | 10813 MiB |
| from small pool | 36 MiB | 80 MiB | 144 MiB | 107 MiB |
|---------------------------------------------------------------------------|
| Requested memory | 2698 MiB | 6958 MiB | 13600 MiB | 10901 MiB |
| from large pool | 2662 MiB | 6884 MiB | 13456 MiB | 10794 MiB |
| from small pool | 36 MiB | 80 MiB | 144 MiB | 107 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 3120 MiB | 7078 MiB | 8742 MiB | 5622 MiB |
| from large pool | 3072 MiB | 6996 MiB | 8652 MiB | 5580 MiB |
| from small pool | 48 MiB | 82 MiB | 90 MiB | 42 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 430167 KiB | 610561 KiB | 9217 MiB | 8797 MiB |
| from large pool | 418176 KiB | 599936 KiB | 9072 MiB | 8664 MiB |
| from small pool | 11991 KiB | 13838 KiB | 144 MiB | 132 MiB |
|---------------------------------------------------------------------------|
| Allocations | 430 | 1503 | 3658 | 3228 |
| from large pool | 307 | 706 | 1306 | 999 |
| from small pool | 123 | 797 | 2352 | 2229 |
|---------------------------------------------------------------------------|
| Active allocs | 430 | 1503 | 3658 | 3228 |
| from large pool | 307 | 706 | 1306 | 999 |
| from small pool | 123 | 797 | 2352 | 2229 |
|---------------------------------------------------------------------------|
| GPU reserved segments | 141 | 252 | 305 | 164 |
| from large pool | 117 | 211 | 260 | 143 |
| from small pool | 24 | 41 | 45 | 21 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 178 | 182 | 1528 | 1350 |
| from large pool | 80 | 84 | 481 | 401 |
| from small pool | 98 | 100 | 1047 | 949 |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
Tue May 30 22:15:48 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla V100-PCIE-16GB On | 00000001:00:00.0 Off | Off |
| N/A 28C P0 41W / 250W| 3500MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
Thanks Ben. Much appreciated. I'll look into it and see what I can find.
Hi @tbenthompson, I've reproduced this on TransformerLens v1.2.2
What version of TransformerLens did you use to produce this? I'm just trying to narrow down how long this has been an issue for
Great! Glad the repro works! I produced this on 1.2.2.
I've been investigating this with @pranavgade20 today. We haven't got to the bottom of the issue yet, but I'm leaving some notes here for anybody to pick up.
Circular reference
The HookedTransformer contains a circular reference, as proven in this snippet:
model = HookedTransformer.from_pretrained("gpt2-small")
assert model.mod_dict[""] == model
We attempted to fix this with the following patch, which succeeded in removing the circular reference but, surprisingly, did not have any effect on the memory usage:
diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py
index 2802943..dd4f93b 100644
--- a/transformer_lens/hook_points.py
+++ b/transformer_lens/hook_points.py
@@ -139,10 +139,11 @@ class HookedRootModule(nn.Module):
         self.mod_dict = {}
         self.hook_dict: Dict[str, HookPoint] = {}
         for name, module in self.named_modules():
-            module.name = name
-            self.mod_dict[name] = module
-            if "HookPoint" in str(type(module)):
-                self.hook_dict[name] = module
+            if name != "":
+                module.name = name
+                self.mod_dict[name] = module
+                if "HookPoint" in str(type(module)):
+                    self.hook_dict[name] = module
 
     def hook_points(self):
         return self.hook_dict.values()
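A cheap way to check whether the model object itself is being collected (as opposed to individual tensors being kept alive elsewhere) is a weakref. This is just a quick sketch to help with debugging, not part of the patch above:

import gc
import weakref

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
ref = weakref.ref(model)   # a weak reference does not keep the model alive

del model
gc.collect()
torch.cuda.empty_cache()

# If something (e.g. the mod_dict cycle, or an external cache) still references
# the model, ref() returns it; if the model was actually freed, ref() returns None.
print("model still alive:", ref() is not None)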
Dangling parameters
The following snippet shows that many tensors are still in scope after deleting the model and running the garbage collector:
import torch
import gc
from pympler import refbrowser
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
del model
gc.collect()
torch.cuda.empty_cache()

count = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            count += 1
            print(type(obj), obj.size())
            cb = refbrowser.ConsoleBrowser(obj)
            cb.print_tree()
    except Exception as e:
        pass
print(f"Found {count} dangling objects")
Output:
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
torch.nn.parameter.Parameter-+-dict-+-module(__main__)-+-dict
| | +-dict
| | +-dict
| | +-dict
| | +-list
| |
| +-list--list_iterator
|
+-list--list_iterator
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
torch.nn.parameter.Parameter-+-dict-+-module(__main__)-+-dict
| | +-dict
| | +-dict
| | +-dict
| | +-list
| |
| +-list--list_iterator
|
+-list--list_iterator
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([50257, 768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
...
Possible next steps
- dig into the ConsoleBrowser output to find out where the dangling references are coming from
  - a useful pattern here is refbrowser.ConsoleBrowser(obj, str_func=lambda x: pdb.set_trace() or str(x)).print_tree(), which will launch the debugger inside the ConsoleBrowser (a runnable sketch of this pattern is just below this list)
- maybe try to reproduce all of this on CPU in order to work locally, and maybe use refbrowser.InteractiveBrowser, which might make debugging easier
- try to reproduce on older versions of TransformerLens
- strace
- ???
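Here is a rough, runnable version of the refbrowser + pdb pattern from the first bullet (assumes pympler is installed; loading on CPU keeps it light):

import gc
import pdb

import torch
from pympler import refbrowser
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small", device="cpu")
del model
gc.collect()

for obj in gc.get_objects():
    if isinstance(obj, torch.nn.Parameter):
        # str_func is called for every node in the reference tree, so this
        # drops into pdb with the referrer chain available for inspection.
        refbrowser.ConsoleBrowser(
            obj, str_func=lambda x: pdb.set_trace() or str(x)
        ).print_tree()
        break   # one surviving parameter is usually enough to find the culprit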
I couldn't repro the issue on CPU on an M1 Mac. Confusingly, the 410m TransformerLens model seems to use basically no RAM.
EDIT: I ran it for longer and noticed that swap usage went up to 30GB, so it seems like the issue does occur on CPU as well!
from transformer_lens import HookedTransformer
import transformers
import psutil
from time import sleep
import gc
process = psutil.Process()
print(f"PID: {process.pid}")
def print_memory_usage():
    print(f"Memory usage: {process.memory_info().rss / 1024 // 1024} MiB")

model_param_str = ['70m', '160m', '410m']
model_names = {mps: f"EleutherAI/pythia-{mps}-deduped" for mps in model_param_str}

print_memory_usage()

print("Loading huggingface models")
for param_str, model_name in model_names.items():
    model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)
sleep(1)
print_memory_usage()

print("Cleaning up huggingface models")
del model
gc.collect()
sleep(1)
print_memory_usage()

print("Loading TransformerLens models")
for param_str, model_name in model_names.items():
    model = HookedTransformer.from_pretrained(model_name, device="cpu")
sleep(1)
print_memory_usage()

print("Cleaning up TransformerLens models")
del model
gc.collect()
sleep(1)
print_memory_usage()
Output:
PID: 23096
Memory usage: 270.0 MiB
Loading huggingface models
Memory usage: 2051.0 MiB
Cleaning up huggingface models
Memory usage: 992.0 MiB
Loading TransformerLens models
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-410m-deduped into HookedTransformer
Memory usage: 1019.0 MiB
Cleaning up TransformerLens models
Memory usage: 852.0 MiB
OK, after playing around with various tools, I ran the fil profiler. As you can see in the screenshot below, the process is allocating nearly 15GB of memory to load GPT2 10 times in a loop (compared to around 2.7GB to load the model once).
I'm pretty sure the tool is telling us which lines of code are responsible for the majority of memory allocation at peak usage. For example, it shows that the line state_dict[k] = v.to( was responsible for 3GB, or 25% of the allocated memory at peak, when I ran this script. (Note: the line numbers reported by the tool seem to be slightly out of sync, so I'm not yet sure exactly which occurrence of this expression is responsible.)
It seems like the issues are mostly related to creating new tensors, so maybe we are doing something wrong when we create them (e.g. not detaching them from the computational graph??)
If anyone wants to look into this with me tomorrow then lmk!
Screenshot
I've included a screenshot here, but this is best viewed interactively. Download this folder and open index.html to do that: https://www.file.io/dCWH/download/LnE5aU1j0Ftt
Steps to repro
Code
# detect-leak.py
from transformer_lens import HookedTransformer
from tqdm import tqdm
for _ in tqdm(range(10)):
    model = HookedTransformer.from_pretrained("gpt2-small", device="cpu")
Command
pip install filprofiler
fil-profile run detect-leak.py
Update! I've managed to get the memory usage down from 15GiB to 5GiB in my test.
If you compare the diagram below with the one from my previous post, all of the overhead from load_and_process_state_dict is now gone for GPT2. The cause was that we weren't calling detach() when adding the weights from the reference model to the state dict!
I will raise a PR for this. The next step is to tackle the next-worst offender: move_model_modules_to_device.
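For anyone following along, here is a minimal sketch of the failure mode being described. The Linear layer and the clone() call are just stand-ins, not the actual TransformerLens weight-processing code:

import torch

# Stand-in for the HF reference model whose weights get processed into the state dict.
reference = torch.nn.Linear(1024, 1024)

# Any tensor derived from a parameter that requires grad stays attached to the
# autograd graph, and that graph keeps a reference back to reference.weight,
# so the reference model's weights can't be freed while this tensor is kept.
processed = reference.weight.clone()
print(processed.grad_fn)   # a CloneBackward grad_fn: still attached to the graph

# Detaching (roughly what the fix does) drops the graph reference, so nothing
# kept in the state dict pins the reference model's weights.
detached = reference.weight.clone().detach()
print(detached.grad_fn)    # None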
Sweet!! Glad you're getting to the root of the issue. This will be very helpful for anyone using TransformerLens with larger models.
Unfortunately, when I reran @tbenthompson's reproducing example, the GPU usage was unchanged. Seems like this is a real issue, but it might be a separate one.
Oh man, sorry about that! (and the other inevitable bugs in the guts of core TL code...) I appreciate the heroic effort!
I think this is mostly because Python/torch.cuda is being lazy and not freeing up memory because it doesn't need to. You can force it to free memory by limiting the amount of memory it has, with something like this:
import gc
import os

import torch
torch.cuda.set_per_process_memory_fraction(0.4, 0)  # adjust this fraction to correspond to ~2.5 GiB of VRAM

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
gc.collect()
gc.collect()  # two calls are sometimes necessary because of unfortunate timing
torch.cuda.empty_cache()
print("===================================== peak usage")
os.system('nvidia-smi')  # this says 1.7 GiB is taken up
This uses ~2322MiB of VRAM on my 3060. Removing the set_per_process_memory_fraction call takes up ~3050MiB.
This is probably not the whole story, but some of it does appear to be just cuda not freeing up memory because it is slightly faster.
Speculation:
If I set gc.set_debug(gc.DEBUG_STATS), I can see a lot of gc passes in the young generation, but since the program's footprint in RAM is just a few megabytes, it almost never tries to clean up the old generation. This means that the pointer to a GiB of VRAM is only a handful of bytes, and thus not a good target for getting cleaned up.
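To make the generational point concrete, here is a small sketch using only the standard library. The collector's thresholds are counts of tracked objects, not bytes, which is why a few tiny Python objects pinning gigabytes of VRAM rarely trigger an automatic full collection:

import gc

gc.set_debug(gc.DEBUG_STATS)   # print a line for every collection pass
print(gc.get_threshold())      # e.g. (700, 10, 10): allocation counts, not sizes
print(gc.get_count())          # current per-generation object counts

# An explicit gc.collect() always runs a full collection, including the oldest
# generation, which is where a long-lived leaked model would end up.
gc.collect()
gc.set_debug(0)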
Possibly unrelated: I can get torch.cuda to release t = torch.zeros((2000, 1000, 1000), device='cuda', dtype=torch.uint8), i.e. 2 GB of memory. It should be just 1.3 GB if we consider the peak usage after deleting the model, so doing this does force some more memory to be released. If I delete t immediately after it is allocated, it does get cleaned up, probably because it is still in the young generation.
Since this causes the memory usage to drop down to ~1 GiB for gpt-2, this is probably the best patch to force torch to free at least some memory until more details are figured out.
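Putting the workarounds from this thread together, a stopgap cleanup helper might look something like this (just a sketch; the name free_memory is mine, not a TransformerLens API):

import gc

import torch

def free_memory():
    # Run the garbage collector twice (per the timing note above), then ask the
    # CUDA caching allocator to hand unused blocks back to the driver.
    gc.collect()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Usage: drop your own references first, then clean up.
# del model
# free_memory()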