
MetaTensor and DistributedDataParallel bug (SyncBatchNormBackward is a view and is being modified inplace)

myron opened this issue on Oct 07, 2022 · 3 comments

After upgrading from MONAI 0.9.0 to 1.0.0, my 3D segmentation code fails when using DistributedDataParallel (multi-GPU), most likely because the transforms now return MetaTensor.

The error is: RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

The same issue was reported (for 2D MIL classification) in https://github.com/Project-MONAI/MONAI/discussions/5081 and https://github.com/Project-MONAI/MONAI/issues/5198.

I've traced it to this commit: https://github.com/Project-MONAI/MONAI/pull/4548/commits/63e36b6cb41e163024729010534cd9363c6356dc (prior to it, the code works fine).

It seems the issue is that the dataloader now returns data as MetaTensor (and not torch.Tensor as before); e.g. in https://github.com/Project-MONAI/tutorials/blob/main/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py#L51 both data and target are MetaTensor.

If converting explicitly (on GPU or CPU):

data = torch.Tensor(data)
target = torch.Tensor(target)

then the code runs fine, though a bit slower. It seems something about MetaTensor is causing the problem.
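
An alternative to the explicit torch.Tensor(...) cast is to strip the metadata with MONAI's MetaTensor.as_tensor(), which returns the underlying torch.Tensor. A minimal sketch (the tensor shape is only illustrative):

import torch
from monai.data import MetaTensor

# a batch as the MONAI 1.0 dataloader would return it
batch = MetaTensor(torch.rand(2, 1, 8, 8, 8))
# drop the metadata wrapper before the DistributedDataParallel forward pass
plain = batch.as_tensor()
print(type(batch).__name__, type(plain).__name__)  # MetaTensor Tensor

Either conversion hands the model a plain torch.Tensor, which sidesteps the SyncBatchNormBackward view/inplace error described above.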

myron · Oct 07 '22 03:10

Thanks for reporting. I'm able to reproduce with torchrun --nnodes=1 --nproc_per_node=2 test.py using this test.py:

import torch.distributed as dist

import torch
from torchvision import models
from monai.data import MetaTensor

torch.autograd.set_detect_anomaly(True)

def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"rank {rank}")
    device = rank

    mod = models.resnet50(pretrained=True).to(device)
    optim = torch.optim.Adam(mod.parameters(), lr=1e-3)
    # the input is a MetaTensor rather than a plain torch.Tensor; this triggers the error below
    z1 = MetaTensor(torch.zeros(1, 3, 128, 128)).to(device)

    mod = torch.nn.SyncBatchNorm.convert_sync_batchnorm(mod)
    mod = torch.nn.parallel.DistributedDataParallel(mod, device_ids=[rank], output_device=rank)

    out = mod(z1)
    print(out.shape)
    loss = (out**2).mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    print("Stepped.")

if __name__ == "__main__":
    run()
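
For comparison, feeding the model the underlying torch.Tensor instead of the MetaTensor is expected to avoid the error, consistent with the workaround above; an illustrative one-line change to the script (assuming MetaTensor.as_tensor()):

    # inside run(), replace the forward pass with:
    out = mod(z1.as_tensor())  # plain torch.Tensor input instead of the MetaTensor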

I'll submit a PR to fix this.

wyli · Oct 07 '22 08:10

This looks like a PyTorch issue; I created a bug report: https://github.com/pytorch/pytorch/issues/86456.

wyli · Oct 07 '22 10:10

Because the upstream bug has not yet been fixed, this ticket should be kept open.

KumoLiu · Dec 20 '23 06:12