
DistributedDataParallel

Open dgm2 opened this issue 2 years ago • 2 comments

It seems that a DistributedDataParallel (DDP) PyTorch setup is not supported in POT, specifically for the emd2 computation. Any workaround ideas for making this work? Or any example of multi-GPU setups for OT?

Ideally, I would like to make OT work with this torch setup: https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py

Many thanks

Example of the DDP failure:

  ot.emd2(a, b, dist)
  File "/python3.8/site-packages/ot/lp/__init__.py", line 468, in emd2
    nx = get_backend(M0, a0, b0)
  File "/python3.8/site-packages/ot/backend.py", line 168, in get_backend
    return TorchBackend()
  File "/python3.8/site-packages/ot/backend.py", line 1517, in __init__
    self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable

My current workaround is changing self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) to self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device=device_id)), passing the device id to the backend and reinstalling POT from source.
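For illustration, a minimal sketch of what that one-line change amounts to; `local_rank` here is a hypothetical stand-in for whatever device id a patched TorchBackend would receive (it is not a parameter of the released backend):

```python
import torch

# Sketch of the workaround described above: build the CUDA probe tensor on an
# explicit device instead of the generic 'cuda' device.
local_rank = 0  # hypothetical; in a real DDP process this comes from the launcher

type_list = [torch.tensor(1, dtype=torch.float32)]
if torch.cuda.is_available():
    # device=local_rank pins the probe tensor to this process's own GPU,
    # which can avoid touching a device already owned by another DDP rank.
    type_list.append(torch.tensor(1, dtype=torch.float32, device=local_rank))
```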

dgm2 avatar Jun 06 '22 10:06 dgm2

Hello @dgm2,

Does this workaround work? Note that the list is there mainly for debugging and tests (so that we can run them on all available devices), so I'm a bit surprised if this is the only bottleneck for running POT with DDP.

We are obviously interested in your contribution if you manage to make it work properly (we do not have multiple GPUs, so it is a bit hard to implement and debug on our side). The device_id should probably be detected automatically at get_backend/backend creation; the backends should not need parameters in order to remain practical to use.
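A minimal sketch of that automatic-detection idea, assuming each DDP process has already called torch.cuda.set_device(local_rank) (as the usual launchers do); this is not POT code, just an illustration of creating the probe tensors without hard-coding device='cuda':

```python
import torch

def cuda_probe_tensors():
    """Build CUDA probe tensors on this process's assigned GPU.

    Sketch only: instead of the generic 'cuda' device, use the device the
    current process has been pinned to, so each DDP rank only touches its
    own GPU when the backend is instantiated.
    """
    probes = []
    if torch.cuda.is_available():
        dev = torch.device("cuda", torch.cuda.current_device())
        probes.append(torch.tensor(1, dtype=torch.float32, device=dev))
    return probes
```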

rflamary avatar Jun 07 '22 06:06 rflamary

Hello @dgm2, could you provide us with the exact code you used to get this error? I ran https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py with 4 GPUs and ot.emd2 as the loss function, yet did not get any error; everything seems to have run smoothly whether the distribution was handled by torch or slurm.
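For anyone trying to reproduce this, a condensed sketch of that kind of setup (model and data reduced to toy tensors, launched with torchrun); this is only illustrative of using ot.emd2 as the loss inside a DDP process, not the exact script that was run:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import ot  # POT

def demo():
    # One process per GPU, as in the pytorch/examples DDP script.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = nn.Linear(10, 10).to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    x = torch.randn(32, 10, device=device)
    target = torch.randn(32, 10, device=device)
    out = model(x)

    # Uniform weights over the samples of each point cloud.
    a = torch.full((32,), 1.0 / 32, device=device)
    b = torch.full((32,), 1.0 / 32, device=device)

    # Squared Euclidean cost matrix between predictions and targets.
    M = ot.dist(out, target)

    # emd2 returns the OT loss and is differentiable with the torch backend,
    # so it can be used directly as the training loss.
    loss = ot.emd2(a, b, M)
    loss.backward()

if __name__ == "__main__":
    demo()
```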

ncassereau-idris avatar Jun 09 '22 08:06 ncassereau-idris