Memory leak when using JaccardIndex during multi-gpu training
🐛 Bug
When using JaccardIndex as a metric during training on multiple GPUs, a memory leak causes the training to crash.
To Reproduce
The problem arises when the predictions and targets are returned from the training_step function,
return {"loss": loss, 'preds': logits.detach(), 'target': labels_wo_unknown.detach()}
which are then used in training_step_end to update the metric,
self.train_iou(outputs['preds'], outputs['target'])
and the metric is finally gathered across all GPUs in training_epoch_end:
class_ious = self.train_iou.compute()
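For context, here is roughly how these pieces fit together in my LightningModule (a trimmed sketch; the class name, placeholder model and hyperparameters below are illustrative, not the real ones):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics


class SegmentationModule(pl.LightningModule):
    def __init__(self, num_classes=200):
        super().__init__()
        self.model = torch.nn.Conv2d(3, num_classes, kernel_size=1)  # placeholder network
        self.train_iou = torchmetrics.JaccardIndex(num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        images, labels_wo_unknown = batch
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels_wo_unknown)
        # Returning the GPU tensors here is what seems to trigger the leak
        return {"loss": loss, 'preds': logits.detach(), 'target': labels_wo_unknown.detach()}

    def training_step_end(self, outputs):
        # Update the metric with the step outputs
        self.train_iou(outputs['preds'], outputs['target'])
        return outputs

    def training_epoch_end(self, outputs):
        # Gather and compute the metric across all GPUs at the end of the epoch
        class_ious = self.train_iou.compute()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)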
The memory leak does not occur when sending the results to the CPU before returning them in training_step
return {"loss": loss, 'preds': logits.detach().cpu(), 'target': labels_wo_unknown.detach().cpu()}
However, this of course crashes the metric computation, since it is expected to run on the GPU.
I presume the problem is that PyTorch Lightning is not handling the returned results correctly (i.e. they are never garbage collected).
Do you have any insights into this bug? This behaviour makes torchmetrics (at least JaccardIndex) unusable for multi-GPU training.
The same behaviour is also present in the validation part of the pipeline.
Hi! Thanks for your contribution, great first issue!
Example to reproduce
# main.py
import sys
import torch
import torch.utils.data
import torch.distributed
import torchmetrics
import tqdm


class Dataset(torch.utils.data.Dataset):
    def __init__(self, K, H, W):
        self.K = K
        self.H = H
        self.W = W

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        # Semantic segmentation mock data
        pred = torch.randn((self.K, self.H, self.W))
        target = torch.randint(self.K, (self.H, self.W))
        return pred, target


def main():
    rank = int(sys.argv[1])
    device = torch.device(f'cuda:{rank}')
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:40123',
        world_size=2,
        rank=rank,
    )
    print('Dist:', torch.distributed.get_rank(), torch.distributed.get_world_size())

    dl_train = torch.utils.data.DataLoader(
        Dataset(K=200, H=256, W=256),
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    dl_val = torch.utils.data.DataLoader(
        Dataset(K=200, H=256, W=256),
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    for epoch in range(10):
        miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
        with tqdm.tqdm(dl_train, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                miou.compute()
                bar.set_postfix_str(f'mem={torch.cuda.max_memory_allocated(device)}')
        del miou

        miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
        with tqdm.tqdm(dl_val, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                bar.set_postfix_str(f'mem={torch.cuda.max_memory_allocated(device)}')
        miou.compute()
        del miou


if __name__ == '__main__':
    main()
Launch two processes:
python main.py 0
python main.py 1
Output from rank 0 (similar output on rank 1):
Dist: 0 2
100% 31/31 [00:36<00:00, 1.18s/it, mem=8406025216]
100% 32/32 [00:34<00:00, 1.06s/it, mem=8406665728]
100% 31/31 [00:34<00:00, 1.11s/it, mem=8407306240]
100% 32/32 [00:35<00:00, 1.10s/it, mem=8407946752]
100% 31/31 [00:33<00:00, 1.09s/it, mem=8408587264]
100% 32/32 [00:33<00:00, 1.06s/it, mem=8409227776]
Memory usage slowly increases after every "epoch".
Hi @weders, thanks for raising this issue. Based on the example you have provided, it seems the problem indeed lies with torchmetrics and has nothing to do with Lightning. I will try to look into what is going on.
@SkafteNicki did you succeed in debugging this issue? :otter:
@Borda I was unable to execute the code on my local cluster. Will try again.
I ran another test on a different machine.
System:
- NVIDIA Driver Version: 470.86
- CUDA Version: 11.4
- 2x NVIDIA GeForce GTX 1080 Ti
Setup:
conda create -y -n tmp -c pytorch -c conda-forge pytorch torchvision cudatoolkit=11.6 tqdm
conda activate tmp
pip install torchmetrics
conda env export > environment.yaml
Launch:
python main.py 0
python main.py 1
Outputs:
# python main.py 0
Dist: 0 2
100% 31/31 [00:53<00:00, 1.72s/it, mem=8406025216]
100% 32/32 [00:48<00:00, 1.53s/it, mem=8406665728]
100% 31/31 [00:49<00:00, 1.60s/it, mem=8407306240]
100% 32/32 [00:50<00:00, 1.59s/it, mem=8407946752]
100% 31/31 [00:50<00:00, 1.63s/it, mem=8407946752]
100% 32/32 [00:49<00:00, 1.56s/it, mem=8407946752]
100% 31/31 [00:50<00:00, 1.62s/it, mem=8407946752]
100% 32/32 [00:50<00:00, 1.57s/it, mem=8408587264]
100% 31/31 [00:51<00:00, 1.66s/it, mem=8409227776]
100% 32/32 [00:48<00:00, 1.51s/it, mem=8409868288]
# python main.py 1
Dist: 1 2
100% 31/31 [00:53<00:00, 1.71s/it, mem=8406025216]
100% 32/32 [00:50<00:00, 1.58s/it, mem=8406665728]
100% 31/31 [00:49<00:00, 1.60s/it, mem=8407306240]
100% 32/32 [00:50<00:00, 1.57s/it, mem=8407946752]
100% 31/31 [00:50<00:00, 1.63s/it, mem=8407946752]
100% 32/32 [00:49<00:00, 1.55s/it, mem=8407946752]
100% 31/31 [00:50<00:00, 1.62s/it, mem=8407946752]
100% 32/32 [00:49<00:00, 1.54s/it, mem=8408587264]
100% 31/31 [00:51<00:00, 1.66s/it, mem=8409227776]
100% 32/32 [00:52<00:00, 1.64s/it, mem=8409868288]
@justusschock, mind having a look at it? :otter:
Hi @baldassarreFe, I was finally able to dive deep into this issue and can reproduce the error with the code you provided. Thanks for that :]
The underlying issue here is how garbage collection works in Python. When you del miou in the code, the metric object is not deleted immediately; it is only marked as ready to be deleted. If you really want to free up the memory, you need to do:
del miou
gc.collect()
using the gc module. Additionally, since you are running on the GPU, you will also need to call torch.cuda.empty_cache() to release the memory for other use. So, in total:
del miou
gc.collect()
torch.cuda.empty_cache()
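For reference, this is how the cleanup would slot into the epoch loop of your reproduction script (just a sketch, with the data loading elided):

import gc

import torch
import torchmetrics

device = torch.device('cuda:0')  # as set up in main()

for epoch in range(10):
    miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
    # ... miou.update(pred, target) over the dataloader, miou.compute(), etc. ...
    del miou                  # drop the reference to the metric object
    gc.collect()              # run the garbage collector right away
    torch.cuda.empty_cache()  # return the freed blocks held by the CUDA caching allocator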
That said, there is a much simpler approach for your case: just reuse the same metric. Here is an adjusted version of your main function that uses metric.reset() instead and has constant memory usage:
def main():
    ...  # do distributed and data loading stuff

    # initialize metric once
    miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)

    for epoch in range(10):
        with tqdm.tqdm(dl_train, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                miou.compute()
        miou.reset()  # reset

        with tqdm.tqdm(dl_val, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
            miou.compute()
        miou.reset()  # reset
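In the Lightning pipeline from your original report, the equivalent would be to keep a single metric instance and reset it after the epoch-level compute (a sketch; the hook name follows your report, and what you do with class_ious is up to you):

def training_epoch_end(self, outputs):
    class_ious = self.train_iou.compute()  # gather the metric across all GPUs
    # ... log or otherwise use class_ious ...
    self.train_iou.reset()  # clear the accumulated state before the next epoch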
Hope this clears up any confusion :]
Hello @SkafteNicki, thanks for looking into the issue and for the proposed solution. However, I wonder if there is some underlying bug with memory management that we are overlooking here.
Here's why:
- As far as I understand, Python's garbage collection and GPU memory caching simply mark objects for deletion. When memory is needed, marked objects are collected. With this behavior we should observe memory consumption increase, but we should never get OOM errors. Instead, I've had many runs crash due to OOM.
- The memory increase and later OOM errors only happen in a distributed setting. In single-process mode memory is correctly freed and reused, but multiprocessing interferes with that somehow.
If torchmetrics is not doing anything weird with memory management in DDP, could it be a bug in PyTorch?
It may very well be something on the PyTorch side, in how it interacts with the garbage collector. I can also get the code working if I do something like:
del miou._defaults
del miou.confmat
del miou
where I go in and explicitly delete the tensors that are not otherwise cleaned up by the garbage collector. This of course requires some knowledge of the internals of torchmetrics.
All right, it's probably related to PyTorch then. I don't have enough knowledge of PyTorch's internals to open an issue there, but let's hope it gets fixed in future releases. Meanwhile, I'll work around the OOM errors using your suggestions, thanks!
@baldassarreFe you are welcome, always happy to help, and sorry it took so long to debug. I'll try to keep this in mind if we ever run into other OOM-related problems. Closing the issue.