import cv2 and DistributedDataParallel. (FIND was unable to find an engine)

Open myron opened this issue 2 years ago • 13 comments

EDIT: the bug is reproducible in the newest nvidia/pytorch:22.09-py3 docker container, but not in the older container (older pytorch/cudnn)

Something in MetaTensor makes DistributedDataParallel fail (this is in addition to this bug https://github.com/Project-MONAI/MONAI/issues/5283)

For example this code fails

import torch.distributed as dist
import torch

from monai.data import MetaTensor
#from monai.config.type_definitions import NdarrayTensor

from torch.cuda.amp import autocast  
torch.autograd.set_detect_anomaly(True)

def main():

    ngpus_per_node = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))

def main_worker(rank, ngpus_per_node):

    print(f"rank {rank}")

    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=ngpus_per_node, rank=rank)
    torch.backends.cudnn.benchmark = True

    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)

    print("Done.", out.shape)

if __name__ == "__main__":
    main()

with error

-- Process 6 terminated with the following error:                                                                                                               
Traceback (most recent call last):                                                                                                                              
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap                                                               
    fn(i, *args)                                                                                                                                                
  File "/mnt/amproj/Code/automl/tasks/hecktor22/autoconfig_segresnet/test_monai.py", line 29, in main_worker                                                    
    out = model(x)                                                                                                                                              
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl                                                            
    return forward_call(*input, **kwargs)                                                                                                                       
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward                                                         
    output = self._run_ddp_forward(*inputs, **kwargs)                                                                                                           
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward                                                 
    return module_to_run(*inputs[0], **kwargs[0])                                                                                                               
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl                                                            
    return forward_call(*input, **kwargs)                                                                                                                       
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward                                                                  
    return self._conv_forward(input, self.weight, self.bias)                                                                                                    
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

The MetaTensor is actually never used/initialized here, but something in it (or its imports) makes the code fail. Since we import MetaTensor everywhere, any code with it fails. I've traced it down to this import (inside of MetaTensor.py): from monai.config.type_definitions import NdarrayTensor

importing this line also makes the code fail.

Somehow it confuses conv3d operation, and possibly other operations

myron avatar Oct 07 '22 21:10 myron

this might be related to the cuda/cudnn versions. but I can't reproduce locally, with cuda 10.2 or 11.7, cudnn 7605 or 8500. could you try to reproduce with a fresh environment and if possible report back python -c 'import monai; monai.config.print_debug_info()'?

wyli avatar Oct 07 '22 21:10 wyli

This is on an NVIDIA V100 16GB x 8 NGC instance, using the NVIDIA pytorch container (either nvidia/pytorch:22.09-py3 or nvidia/pytorch:22.0-py3) and the latest MONAI 1.0.0 (via pip install monai)

Printing MONAI config...

MONAI version: 1.0.0
Numpy version: 1.22.2
Pytorch version: 1.13.0a0+d0d6b1f
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 170093375ce29267e45681fcec09dfa856e1d7e7
MONAI __file__: /opt/conda/lib/python3.8/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: 2.10.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.14.0a0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.4.4
einops version: 0.5.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



Printing system config...

System: Linux
Linux version: Ubuntu 20.04.5 LTS
Platform: Linux-5.4.0-109-generic-x86_64-with-glibc2.10
Processor: x86_64
Machine: x86_64
Python version: 3.8.13
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 40
Num logical CPUs: 80
Num usable CPUs: 80
CPU usage (%): [0.6, 0.3, 0.6, 1.3, 0.0, 0.6, 0.3, 0.6, 2.9, 1.0, 1.6, 0.6, 0.6, 1.6, 0.3, 0.6, 1.9, 0.3, 0.3, 0.6, 0.3, 0.0, 0.0, 0.3, 0.3, 0.3, 0.6, 0.6, 1.3, 7.3, 0.3, 0.9, 0.0, 0.3, 0.6, 0.3, 0.3, 0.3, 0.0, 1.6, 0.0, 0.0, 0.3, 1.3, 0.6, 1.3, 1.9, 1.3, 0.6, 7.3, 6.6, 0.6, 0.0, 0.6, 0.3, 0.3, 0.3, 0.3, 2.5, 0.9, 0.3, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.0, 0.3, 0.3, 0.3, 0.3, 0.0, 0.3, 0.6, 0.3, 0.3, 0.3, 0.9, 100.0]
CPU freq. (MHz): 2694
Load avg. in last 1, 5, 15 mins (%): [1.6, 1.8, 2.3]
Disk usage (%): 46.2
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 503.8
Available memory (GB): 484.9
Used memory (GB): 16.2


Printing GPU config...

Num GPUs: 8
Has CUDA: True
CUDA version: 11.8
cuDNN enabled: True
cuDNN version: 8600
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'compute_90']
GPU 0 Name: Tesla V100-SXM2-16GB-N
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 80
GPU 0 Total memory (GB): 15.8
GPU 0 CUDA capability (maj.min): 7.0
GPU 7 Name: Tesla V100-SXM2-16GB-N
GPU 7 Is integrated: False
GPU 7 Is multi GPU board: False
GPU 7 Multi processor count: 80
GPU 7 Total memory (GB): 15.8
GPU 7 CUDA capability (maj.min): 7.0

myron avatar Oct 07 '22 21:10 myron

This is on an NVIDIA V100 16GB x 8 NGC instance, using the NVIDIA pytorch container (either nvidia/pytorch:22.09-py3 or nvidia/pytorch:22.0-py3) and the latest MONAI 1.0.0 (via pip install monai)

thanks, 22.09 hasn't been tested yet https://github.com/Project-MONAI/MONAI/issues/5269 cc @Nic-Ma

wyli avatar Oct 07 '22 21:10 wyli

yeah, you're right 22.08 pytorch container is working fine, which includes

CUDA version: 11.7
cuDNN version: 8500
Pytorch version: 1.13.0a0+d321be6

So it's related to newer cudnn or newer pytorch (in combination with monai==1.0.0)

But even with the newest 22.09 container, monai==0.9.0 works fine (only 1.0.0 fails)

myron avatar Oct 07 '22 22:10 myron

Hi @myron , @wyli ,

I ran the test program on a V100-32G with the latest MONAI and the 22.09 docker image, and got the output below:

root@apt-sh-ai:/workspace/data/medical/MONAI# python test_ddp.py 
rank 0
rank 1
2022-10-08 10:13:58,995 - Added key: store_based_barrier_key:1 to store for rank: 0
2022-10-08 10:13:59,005 - Added key: store_based_barrier_key:1 to store for rank: 1
2022-10-08 10:13:59,005 - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
2022-10-08 10:13:59,005 - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
is_namedtuple is deprecated, please use the python checks instead
is_namedtuple is deprecated, please use the python checks instead
Traceback (most recent call last):
  File "test_ddp.py", line 32, in <module>
    main()
  File "test_ddp.py", line 13, in main
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/data/medical/MONAI/test_ddp.py", line 27, in main_worker
    out = model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

Thanks.

Nic-Ma avatar Oct 08 '22 10:10 Nic-Ma

After further analysis, here are my findings:

  1. Any MONAI import causes the error; for example, changing from monai.data import MetaTensor to from monai.config.deviceconfig import print_config also triggers it.
  2. If the import is moved into the subprocess (inside the function main_worker()), everything is fine.
  3. If the backend is changed from nccl to gloo, everything is fine.

Since any MONAI import triggers many further imports, maybe some CUDA-related state is not shareable across spawn multiprocessing.
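Findings 2 and 3 could be tried along the following lines. This is only a sketch under assumptions: the thread does not say where inside main_worker() the import was placed, so putting it at the top of the function is a guess; the backend parameter is a hypothetical addition for switching between nccl and gloo; and the torch imports are deferred as well only to keep the module-level imports empty, which the findings do not require.

```python
def main_worker(rank: int, world_size: int, backend: str = "nccl") -> None:
    # Deferred imports: these run inside each spawned worker (finding 2).
    import torch
    import torch.distributed as dist
    from torch.cuda.amp import autocast
    from monai.data import MetaTensor  # noqa: F401  (the import that failed at module level)

    dist.init_process_group(
        backend=backend,  # "gloo" reportedly also avoids the error (finding 3)
        init_method="tcp://127.0.0.1:23456",
        world_size=world_size,
        rank=rank,
    )
    torch.backends.cudnn.benchmark = True

    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)

    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)
    print("Done.", out.shape)
    dist.destroy_process_group()


def main(backend: str = "nccl") -> None:
    import torch

    # One worker process per visible GPU, as in the original reproduction script.
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=world_size, args=(world_size, backend))
```

Call main() (or main("gloo")) from under an if __name__ == "__main__": guard, which the spawn start method needs.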

Thanks.

Nic-Ma avatar Oct 08 '22 11:10 Nic-Ma

Thanks @Nic-Ma . In my analysis, monai=0.9.0 works fine in the newest pytorch container, so it's something specific for monai=1.0.0

it seems that this header import alone causes it: from monai.config.type_definitions import NdarrayTensor

myron avatar Oct 08 '22 18:10 myron

it seems it's triggered by import cv2, on driver 470.82.01 and nvcr.io/nvidia/pytorch:22.09-py3 (the root cause is not really from monai...perhaps we report this to the framework team instead).

To reproduce, launch nvcr.io/nvidia/pytorch:22.09-py3, and run python test.py, where test.py has the following content:

import torch.distributed as dist
import torch

import cv2

from torch.cuda.amp import autocast
torch.autograd.set_detect_anomaly(True)

def main():

    ngpus_per_node = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))

def main_worker(rank, ngpus_per_node):

    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=ngpus_per_node, rank=rank)
    torch.backends.cudnn.benchmark = True

    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)

if __name__ == "__main__":
    main()

output:

root@3512928:/workspace# python test.py
/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:9: UserWarning: is_namedtuple is deprecated, please use the python checks instead
  warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
Traceback (most recent call last):
  File "test.py", line 27, in <module>
    main()
  File "test.py", line 12, in main
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/test.py", line 24, in main_worker
    out = model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

wyli avatar Oct 08 '22 21:10 wyli

I get the same error if I import from monai.config.type_definitions import NdarrayTensor

instead of import cv2

and that import doesn't import cv2, so it seems there are several ways to trigger this error

myron avatar Oct 09 '22 01:10 myron

Hi @myron ,

The MONAI import logic is different: we import everything even if you only import one component: https://github.com/Project-MONAI/MONAI/blob/dev/monai/__init__.py#L48 So import cv2 may be called somewhere in the codebase, for example: https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/video_dataset.py#L28
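A quick way to check whether a particular import ends up loading cv2 is to run it in a fresh interpreter and look at sys.modules afterwards. The helper below is a hypothetical sketch (import_pulls_in is not a MONAI function) and uses only the standard library:

```python
import subprocess
import sys


def import_pulls_in(package: str, target: str) -> bool:
    """Return True if importing `package` in a fresh interpreter also loads `target`."""
    probe = (
        "import importlib, sys\n"
        f"importlib.import_module({package!r})\n"
        f"print({target!r} in sys.modules)\n"
    )
    # A separate process guarantees that nothing imported earlier hides the result.
    result = subprocess.run(
        [sys.executable, "-c", probe],
        capture_output=True,
        text=True,
        check=True,  # raises CalledProcessError if `package` cannot be imported
    )
    # The package may print on import, so only the last line is the answer.
    return result.stdout.strip().splitlines()[-1] == "True"
```

For the case discussed here, the call would be import_pulls_in("monai.config.type_definitions", "cv2").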

Thanks.

Nic-Ma avatar Oct 09 '22 01:10 Nic-Ma

@Nic-Ma thanks for the reply. I see..

We should reconsider this logic. If someone wants to import only a small component, why do we need to import everything? This seems slow, and it can lead to hard-to-debug bugs like this one in the future.

myron avatar Oct 09 '22 02:10 myron

The current import is not lazy for the first run, but it always walks through the modules in the same import order and easily avoids circular imports. I tried to make it optional but don't have an idea for dealing with the circular imports.
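For an optional dependency such as cv2, one possible direction is to defer the actual import until the first attribute access. The class below is a hypothetical sketch, not MONAI's import mechanism, and it does not address the circular-import ordering problem described above:

```python
import importlib
from types import ModuleType
from typing import Any, Optional


class LazyModule:
    """Proxy that imports the wrapped module on first attribute access."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._module: Optional[ModuleType] = None

    @property
    def loaded(self) -> bool:
        # True once the real module has been imported.
        return self._module is not None

    def __getattr__(self, attr: str) -> Any:
        # Only reached for names not found on the proxy itself,
        # i.e. the attributes of the wrapped module.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)
```

With cv2 = LazyModule("cv2") at module level, nothing is imported until something like cv2.imread is first looked up.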

wyli avatar Oct 09 '22 14:10 wyli

Hi @myron @wyli ,

After more analysis, I found that this issue only occurs when you set torch.backends.cudnn.benchmark = True. To unblock your work, I think you can remove this line or set it to False for now.

Thanks.

Nic-Ma avatar Oct 11 '22 08:10 Nic-Ma

thank you for the reply,

cudnn.benchmark selects the best kernel variant; if we don't use it, we may see a 20% performance drop. I don't think it's an acceptable long-term solution.

I can use 22.08 container, until we find a solution (it's acceptable to me for now)

But we need a solution that doesn't compromise efficiency. And once again, monai=0.9.0 works fine with 22.09 with all these options, so it's something new in monai=1.0.0 that triggers it.

I don't think we can close this issue yet

myron avatar Oct 11 '22 17:10 myron

It's already been addressed by https://github.com/Project-MONAI/MONAI/pull/5293 (by not importing cv2, https://github.com/Project-MONAI/MONAI/blob/dev/monai/__init__.py#L50), with a test case included. What Nic mentions is a possible alternative solution in case cv2 is imported for some other purposes.

wyli avatar Oct 11 '22 20:10 wyli

I see, very good, thank you guys

myron avatar Oct 11 '22 22:10 myron