MONAI
MetaTensor and DistributedDataParallel bug (SyncBatchNormBackward is a view and is being modified inplace)
After upgrading from MONAI 0.9.0 to 1.0.0, my 3D segmentation code fails when using DistributedDataParallel (multi-GPU), most likely because transforms now return the new MetaTensor.
The error is:
RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
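This message is PyTorch's generic guard against in-place edits of a view created inside a custom `torch.autograd.Function`. As a minimal CPU-only sketch (assuming only `torch` is installed), the hypothetical `IdentityFn` below returns its input unchanged, which autograd wraps as a view, so a following in-place op raises the same class of error, while cloning the output first, as the message suggests, avoids it. This is not SyncBatchNorm itself, only an illustration of the check that fires.

```python
import torch


class IdentityFn(torch.autograd.Function):
    """Hypothetical custom Function that returns its input as-is (stand-in, not SyncBatchNorm)."""

    @staticmethod
    def forward(ctx, x):
        # Returning an input unchanged makes autograd hand back a view of it.
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


def inplace_after_custom_fn(clone_first: bool) -> str:
    """Run an in-place op on the Function's output; return the error text or "ok"."""
    x = torch.randn(4, requires_grad=True) * 2  # non-leaf tensor that requires grad
    y = IdentityFn.apply(x)
    if clone_first:
        y = y.clone()  # the fix suggested by the error message
    try:
        y.mul_(2)  # in-place op on the custom Function's output
    except RuntimeError as err:
        return str(err)
    return "ok"


print(inplace_after_custom_fn(clone_first=False))
print(inplace_after_custom_fn(clone_first=True))
```

In the real model the in-place op would be something like ResNet's `ReLU(inplace=True)` following a batch-norm layer, though that mapping is an assumption rather than something confirmed in this thread.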
The same issue was reported (for 2D MIL classification) in https://github.com/Project-MONAI/MONAI/discussions/5081 and https://github.com/Project-MONAI/MONAI/issues/5198.
I've traced it to this commit: https://github.com/Project-MONAI/MONAI/pull/4548/commits/63e36b6cb41e163024729010534cd9363c6356dc (before it, the code works fine).
The issue seems to be that the dataloader now returns data as MetaTensor (rather than torch.Tensor as before). For example, here https://github.com/Project-MONAI/tutorials/blob/main/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py#L51 both `data` and `target` are MetaTensor instances.
If I convert them explicitly (on GPU or CPU):

```python
data = torch.Tensor(data)
target = torch.Tensor(target)
```

then the code runs fine, but a bit slower. It seems something is wrong with MetaTensor.
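A minimal sketch of the same workaround applied to a whole batch, assuming batches are tensors, dicts, lists, or tuples: `Tensor.as_subclass(torch.Tensor)` returns a plain `torch.Tensor` that shares storage with the original, so only the Python type changes. The helper name `strip_tensor_subclass` and the `TaggedTensor` stand-in for MetaTensor are hypothetical; for a single tensor, MONAI's `MetaTensor.as_tensor()` should, as far as we know, do the equivalent.

```python
from collections.abc import Mapping

import torch


def strip_tensor_subclass(obj):
    """Recursively turn tensor-subclass instances (e.g. MetaTensor) into plain torch.Tensor."""
    if isinstance(obj, torch.Tensor):
        if type(obj) is torch.Tensor:
            return obj
        # Same data pointer, device and dtype; only the Python class is dropped.
        return obj.as_subclass(torch.Tensor)
    if isinstance(obj, Mapping):
        return {k: strip_tensor_subclass(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [strip_tensor_subclass(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(strip_tensor_subclass(v) for v in obj)
    return obj  # non-tensor leaves (ints, strings, ...) pass through unchanged


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for a metadata-carrying subclass such as MetaTensor."""


batch = {"image": torch.zeros(2, 3).as_subclass(TaggedTensor), "label": [1, 0]}
plain = strip_tensor_subclass(batch)
print(type(batch["image"]).__name__, "->", type(plain["image"]).__name__)
```

In a training loop this could be applied as `data, target = strip_tensor_subclass((data, target))` right after unpacking the batch; any metadata carried by the subclass is simply not visible on the returned tensors.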
Thanks for reporting. I'm able to reproduce this with `torchrun --nnodes=1 --nproc_per_node=2 test.py` using this test.py:
```python
import torch
import torch.distributed as dist
from torchvision import models

from monai.data import MetaTensor

# Make autograd report the forward op behind the failing check.
torch.autograd.set_detect_anomaly(True)


def run():
    # torchrun provides RANK, WORLD_SIZE, MASTER_ADDR, etc. via environment variables.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"rank {rank}")
    device = rank  # one GPU per process

    mod = models.resnet50(pretrained=True).to(device)
    optim = torch.optim.Adam(mod.parameters(), lr=1e-3)

    # The input is a MetaTensor rather than a plain torch.Tensor.
    z1 = MetaTensor(torch.zeros(1, 3, 128, 128)).to(device)

    mod = torch.nn.SyncBatchNorm.convert_sync_batchnorm(mod)
    mod = torch.nn.parallel.DistributedDataParallel(mod, device_ids=[rank], output_device=rank)

    out = mod(z1)
    print(out.shape)
    loss = (out**2).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    print("Stepped.")


if __name__ == "__main__":
    run()
```
I'll submit a PR to fix this.
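Why a MetaTensor input ends up inside SyncBatchNorm at all: tensor subclasses propagate through ordinary torch ops via the default `__torch_function__`, so intermediate activations and the model output keep the subclass type. A CPU-only sketch, assuming a hypothetical `TaggedTensor` as a stand-in for MetaTensor and a small conv/BN/ReLU stack as a stand-in for ResNet-50:

```python
import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical stand-in for MetaTensor; inherits the default __torch_function__."""


def output_type_of(model: torch.nn.Module, x: torch.Tensor) -> type:
    """Return the Python type of the model's output for input x."""
    with torch.no_grad():
        return type(model(x))


# Small stand-in for the conv -> batch norm -> in-place ReLU pattern used in ResNet.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(4),
    torch.nn.ReLU(inplace=True),
)

tagged = torch.zeros(2, 3, 8, 8).as_subclass(TaggedTensor)
print(output_type_of(model, tagged).__name__)  # the subclass survives every layer
print(output_type_of(model, tagged.as_subclass(torch.Tensor)).__name__)  # plain input, plain output
```

Converting the input once, before the forward pass, is therefore enough to keep the subclass out of every layer, which is consistent with the `torch.Tensor(data)` workaround reported above.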
This looks like a PyTorch issue; I created a bug report (https://github.com/pytorch/pytorch/issues/86456).
Because the upstream bug has not been fixed yet, this ticket should be kept open.