
Memory leak when using JaccardIndex during multi-gpu training

weders opened this issue on Aug 22 '22.

🐛 Bug

When using JaccardIndex as a metric during training on multiple GPUs, a memory leak causes the training to crash.

To Reproduce

The problem occurs when the predictions and targets are returned from the training_step function:

return {"loss": loss, 'preds': logits.detach(), 'target': labels_wo_unknown.detach()}

These outputs are then used in training_step_end to update the metric:

self.train_iou(outputs['preds'], outputs['target'])

Finally, the metric is computed in training_epoch_end, which gathers the states across all GPUs:

class_ious = self.train_iou.compute()

The memory leak does not occur when the results are moved to the CPU before being returned from training_step:

return {"loss": loss, 'preds': logits.detach().cpu(), 'target': labels_wo_unknown.detach().cpu()}

However, this of course crashes the metric computation, since the metric is expected to run on the GPU.

I suspect the problem is that PyTorch Lightning is not handling the returned results correctly (i.e. their garbage collection).

Do you have any insights into this bug? This behaviour makes torchmetrics (at least JaccardIndex) unusable for multi-GPU training.

The same behaviour is also present in the validation part of the pipeline.

weders, Aug 22 '22

Hi! Thanks for your contribution, great first issue!

github-actions[bot], Aug 22 '22

Example to reproduce

# main.py
import sys
import torch
import torch.utils.data
import torch.distributed
import torchmetrics
import tqdm


class Dataset(torch.utils.data.Dataset):
    def __init__(self, K, H, W):
        self.K = K
        self.H = H
        self.W = W

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        # Semantic segmentation mock data
        pred = torch.randn((self.K, self.H, self.W))
        target = torch.randint(self.K, (self.H, self.W))
        return pred, target


def main():
    rank = int(sys.argv[1])
    device = torch.device(f'cuda:{rank}')

    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:40123',
        world_size=2,
        rank=rank,
    )
    print('Dist:', torch.distributed.get_rank(), torch.distributed.get_world_size())

    dl_train = torch.utils.data.DataLoader(
        Dataset(K=200, H=256, W=256),
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    dl_val = torch.utils.data.DataLoader(
        Dataset(K=200, H=256, W=256),
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )
    for epoch in range(10):
        miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
        with tqdm.tqdm(dl_train, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                miou.compute()
                bar.set_postfix_str(f'mem={torch.cuda.max_memory_allocated(device)}')
        del miou

        miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
        with tqdm.tqdm(dl_val, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                bar.set_postfix_str(f'mem={torch.cuda.max_memory_allocated(device)}')
        miou.compute()
        del miou


if __name__ == '__main__':
    main()

Launch two processes:

python main.py 0
python main.py 1
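Both commands have to run at the same time (for example in two terminals), because init_process_group blocks until all world_size=2 ranks have joined. A minimal stdlib launcher sketch, assuming the script is saved as main.py in the current directory; the launch helper and its dry_run flag are hypothetical and not part of the original report:

```python
import subprocess
import sys


def launch(script="main.py", world_size=2, dry_run=False):
    """Start one process per rank concurrently and wait for all of them.

    world_size must match the value hard-coded in main.py (2 there).
    With dry_run=True, only return the command lines that would be run.
    """
    # main.py reads its rank from sys.argv[1]
    commands = [[sys.executable, script, str(rank)] for rank in range(world_size)]
    if dry_run:
        return commands
    # Popen returns immediately, so every rank is started before we wait on any
    procs = [subprocess.Popen(cmd) for cmd in commands]
    return [proc.wait() for proc in procs]
```

Calling launch() returns the exit codes of both ranks once they finish.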

Output from rank 0 (similar output on rank 1):

Dist: 0 2
100% 31/31 [00:36<00:00,  1.18s/it, mem=8406025216]
100% 32/32 [00:34<00:00,  1.06s/it, mem=8406665728]
100% 31/31 [00:34<00:00,  1.11s/it, mem=8407306240]
100% 32/32 [00:35<00:00,  1.10s/it, mem=8407946752]
100% 31/31 [00:33<00:00,  1.09s/it, mem=8408587264]
100% 32/32 [00:33<00:00,  1.06s/it, mem=8409227776]

Memory usage slowly increases after every "epoch".
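The growth is strikingly regular. A quick check of the logged peak values (torch.cuda.max_memory_allocated, copied from the rank-0 output above) shows every new JaccardIndex instance adds exactly 640,512 bytes. Assuming, as an unverified guess about torchmetrics internals, that the metric keeps a 200 × 200 int64 confusion matrix (num_classes=200), one such tensor is 320,000 bytes, so each step is consistent with roughly two of them surviving per loop:

```python
# Peak-memory values (bytes) copied from the rank-0 log, one per loop
peaks = [
    8406025216,
    8406665728,
    8407306240,
    8407946752,
    8408587264,
    8409227776,
]

# Growth between consecutive loops
deltas = [b - a for a, b in zip(peaks, peaks[1:])]
print(deltas)  # [640512, 640512, 640512, 640512, 640512]

# Assumed size of one 200 x 200 int64 confusion matrix (8 bytes per entry)
confmat_bytes = 200 * 200 * 8
print(confmat_bytes)  # 320000
print(deltas[0] / confmat_bytes)  # 2.0016
```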

baldassarreFe, Aug 24 '22

Hi @weders, thanks for raising this issue. Based on the example provided, it seems the problem is indeed with torchmetrics and has nothing to do with Lightning. I will try to look into what is going on.

SkafteNicki, Aug 30 '22

@SkafteNicki did you succeed in debugging this issue? :otter:

Borda, Oct 19 '22

@Borda I was unable to execute the code on my local cluster. Will try again.

SkafteNicki, Oct 25 '22

I ran another test on a different machine.

System:

  • NVIDIA Driver Version: 470.86
  • CUDA Version: 11.4
  • 2x NVIDIA GeForce GTX 1080 Ti

Setup:

conda create -y -n tmp -c pytorch -c conda-forge pytorch torchvision cudatoolkit=11.6 tqdm
conda activate tmp
pip install torchmetrics
conda env export > environment.yaml

Launch:

python main.py 0
python main.py 1

Outputs:

# python main.py 0
Dist: 0 2
100% 31/31 [00:53<00:00,  1.72s/it, mem=8406025216]
100% 32/32 [00:48<00:00,  1.53s/it, mem=8406665728]
100% 31/31 [00:49<00:00,  1.60s/it, mem=8407306240]
100% 32/32 [00:50<00:00,  1.59s/it, mem=8407946752]
100% 31/31 [00:50<00:00,  1.63s/it, mem=8407946752]
100% 32/32 [00:49<00:00,  1.56s/it, mem=8407946752]
100% 31/31 [00:50<00:00,  1.62s/it, mem=8407946752]
100% 32/32 [00:50<00:00,  1.57s/it, mem=8408587264]
100% 31/31 [00:51<00:00,  1.66s/it, mem=8409227776]
100% 32/32 [00:48<00:00,  1.51s/it, mem=8409868288]

# python main.py 1
Dist: 1 2
100% 31/31 [00:53<00:00,  1.71s/it, mem=8406025216]
100% 32/32 [00:50<00:00,  1.58s/it, mem=8406665728]
100% 31/31 [00:49<00:00,  1.60s/it, mem=8407306240]
100% 32/32 [00:50<00:00,  1.57s/it, mem=8407946752]
100% 31/31 [00:50<00:00,  1.63s/it, mem=8407946752]
100% 32/32 [00:49<00:00,  1.55s/it, mem=8407946752]
100% 31/31 [00:50<00:00,  1.62s/it, mem=8407946752]
100% 32/32 [00:49<00:00,  1.54s/it, mem=8408587264]
100% 31/31 [00:51<00:00,  1.66s/it, mem=8409227776]
100% 32/32 [00:52<00:00,  1.64s/it, mem=8409868288]


baldassarreFe, Oct 26 '22

@justusschock, mind having a look at it? :otter:

Borda, Nov 07 '22

Hi @baldassarreFe, I was finally able to dig into this issue and can reproduce the error with the code you provided. Thanks for that :]

The underlying issue here is how garbage collection works in Python. When you del miou in the code, it does not immediately delete the metric object; it only removes the name miou. If the object is still kept alive by a reference cycle, it is freed only when Python's cyclic garbage collector runs. If you really want to free up the memory, you need to do:

import gc

del miou
gc.collect()

using Python's garbage collection module (gc). Additionally, since you are running on GPU, you will also need to call torch.cuda.empty_cache() if you want the memory held by PyTorch's caching allocator returned to the GPU for other uses. In total:

import gc
import torch

del miou
gc.collect()
torch.cuda.empty_cache()
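In CPython, an object is freed as soon as its reference count drops to zero; only objects caught in a reference cycle have to wait for the cyclic collector, which is what gc.collect() runs. A stdlib-only sketch of that difference, where the Holder class is hypothetical and merely stands in for an object like the metric:

```python
import gc
import weakref


class Holder:
    """Hypothetical stand-in for an object that owns a large buffer."""

    def __init__(self, make_cycle):
        self.buffer = bytearray(10_000_000)  # ~10 MB payload
        if make_cycle:
            self.me = self  # self-reference: refcount never drops to zero on its own


def alive_after_del(make_cycle):
    """Return (alive right after del, alive after gc.collect())."""
    gc.collect()  # start from a clean slate
    obj = Holder(make_cycle)
    ref = weakref.ref(obj)  # a weak reference does not keep obj alive
    del obj
    before_collect = ref() is not None
    gc.collect()
    after_collect = ref() is not None
    return before_collect, after_collect


print(alive_after_del(make_cycle=False))  # (False, False): freed immediately by refcounting
print(alive_after_del(make_cycle=True))   # (True, False): freed only by gc.collect()
```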

That said, there is a much simpler approach for your case: just reuse the same metric. Here is an adjusted version of your main function that uses metric.reset() instead and has constant memory usage:

def main():
    ... # do distributed and data loading stuff

    # initialize the metric once
    miou = torchmetrics.JaccardIndex(num_classes=200, average='macro').to(device)
    for epoch in range(10):
        with tqdm.tqdm(dl_train, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
                miou.compute()
        miou.reset()  # reset

        with tqdm.tqdm(dl_val, ncols=0) as bar:
            for i, (pred, target) in enumerate(bar):
                pred = pred.to(device)
                target = target.to(device)
                miou.update(pred, target)
        miou.compute()
        miou.reset()  # reset
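Reusing one metric works because its state has a fixed size and reset() zeroes it in place instead of allocating new buffers. A pure-Python toy of the same update/compute/reset pattern; ToyJaccard is hypothetical and simplified, not how torchmetrics implements JaccardIndex (it takes integer labels rather than logits, and skips classes with an empty union when averaging):

```python
class ToyJaccard:
    """Hypothetical macro-averaged IoU accumulated in a confusion matrix."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # confmat[t][p] counts samples with target t predicted as p
        self.confmat = [[0] * num_classes for _ in range(num_classes)]

    def update(self, preds, targets):
        for p, t in zip(preds, targets):
            self.confmat[t][p] += 1

    def compute(self):
        ious = []
        for c in range(self.num_classes):
            tp = self.confmat[c][c]
            fn = sum(self.confmat[c]) - tp  # row c: target c, predicted otherwise
            fp = sum(row[c] for row in self.confmat) - tp  # column c minus hits
            union = tp + fp + fn
            if union > 0:  # skip classes absent from both preds and targets
                ious.append(tp / union)
        return sum(ious) / len(ious) if ious else 0.0

    def reset(self):
        # Zero the existing buffers in place instead of allocating new ones
        for row in self.confmat:
            for c in range(self.num_classes):
                row[c] = 0


metric = ToyJaccard(num_classes=3)  # created once, outside the epoch loop
state = metric.confmat
scores = []
for epoch in range(2):
    metric.update(preds=[0, 1, 1, 2], targets=[0, 1, 2, 2])
    scores.append(metric.compute())
    metric.reset()  # the same buffers are reused in the next epoch
```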

Hope this clears up any confusion :]

SkafteNicki, Nov 21 '22

Hello @SkafteNicki, thanks for looking into the issue and for the proposed solution. However, I wonder if there is some underlying bug with memory management that we are overlooking here.

Here's why:

  • As far as I understand, Python's garbage collection and GPU memory caching simply mark objects for deletion, and marked objects are collected when memory is needed. With this behavior we might see memory consumption increase, but we should never get OOM errors. Instead, I've had many runs crash due to OOM.

  • The memory increase and the later OOM errors only happen in a distributed setting. In single-process mode, memory is correctly freed and reused, but multiprocessing somehow interferes with that.

If torchmetrics is not doing anything unusual with memory management in DDP, could it be a bug in PyTorch?

baldassarreFe, Nov 21 '22

It may very well be something on the PyTorch side, in how it interacts with the garbage collector. I can also get the code working if I do something like:

del miou._defaults
del miou.confmat
del miou

where I explicitly delete the tensors that the garbage collector does not normally clean up. This of course requires some knowledge of torchmetrics internals.
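This works because reference counting applies to each attribute independently: even if the owning object sits in a reference cycle, deleting the attribute that holds the large state frees that state immediately (assuming nothing else references it), and only the now small owner waits for the cyclic collector. A stdlib sketch with hypothetical Owner and Tensorish classes standing in for the metric and its state tensors:

```python
import gc
import weakref


class Tensorish:
    """Hypothetical stand-in for a large state tensor (supports weak references)."""

    def __init__(self, nbytes):
        self.data = bytearray(nbytes)


class Owner:
    """Hypothetical stand-in for a metric caught in a reference cycle."""

    def __init__(self):
        self.state = Tensorish(10_000_000)
        self.me = self  # the cycle keeps Owner alive after `del`


def free_state_first():
    """Return (state freed by del, owner still alive after del, owner freed by gc)."""
    gc.collect()
    owner = Owner()
    state_ref = weakref.ref(owner.state)
    owner_ref = weakref.ref(owner)
    del owner.state  # drops the only strong reference to the payload
    state_freed = state_ref() is None
    del owner
    owner_leaked = owner_ref() is not None  # still waiting for the cyclic collector
    gc.collect()
    return state_freed, owner_leaked, owner_ref() is None


print(free_state_first())  # (True, True, True)
```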

SkafteNicki, Nov 22 '22

All right, it's probably related to PyTorch then. I don't have enough knowledge of PyTorch's internals to open an issue there, but let's hope it gets fixed in a future release. Meanwhile, I'll work around the OOM errors using your suggestions, thanks!

baldassarreFe, Nov 22 '22

@baldassarreFe you're welcome, always happy to help, and sorry it took so long to debug. I'll try to keep this in mind if we ever run into other OOM problems. Closing the issue.

SkafteNicki, Nov 22 '22