
NCCL Backend does not support ComplexFloat data type

Open JALB-epsilon opened this issue 4 years ago • 30 comments

🐛 Describe the bug

I am using Torch 1.10. This is not the code I am working on, but it reproduces the error.

import torch
import os
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10, dtype=torch.cfloat).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    x = torch.rand(20,10)
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(torch.abs(outputs), labels).backward()
    # update parameters
    optimizer.step()
    print(outputs)

def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    os.environ['MASTER_ADDR'] =  'localhost' 
    os.environ['MASTER_PORT'] = str(np.random.randint(10000, 20000))
    main()

Output: RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat

The NCCL backend should support complex data types, just as backpropagation already does.

Versions

PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: NVIDIA DGX Station (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.15.2.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect

Nvidia driver version: 450.102.04
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.7.6.3
/usr/lib64/libcudnn.so.8.0.5
/usr/lib64/libcudnn_adv_infer.so.8.0.5
/usr/lib64/libcudnn_adv_train.so.8.0.5
/usr/lib64/libcudnn_cnn_infer.so.8.0.5
/usr/lib64/libcudnn_cnn_train.so.8.0.5
/usr/lib64/libcudnn_ops_infer.so.8.0.5
/usr/lib64/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.1
[pip3] torchaudio==0.10.1
[pip3] torchvision==0.11.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.2 py38h20f2e39_0
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] pytorch 1.10.1 py3.8_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.1 py38_cu102 pytorch
[conda] torchvision 0.11.2 py38_cu102 pytorch

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @anjali411 @dylanbespalko @mruberry @Lezcano @nikitaved

JALB-epsilon avatar Jan 21 '22 00:01 JALB-epsilon

oops, we should definitely fix this. As a workaround, view the complex to real and send that to nccl.
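For a hand-written collective call, that workaround might look roughly like the sketch below. It is an illustration, not code from this thread: `all_reduce_complex` and its `reduce_fn` parameter are hypothetical names, and `reduce_fn` exists only so the function can be tried without a process group (it falls back to `torch.distributed.all_reduce`). The sketch assumes a sum reduction, which acts on real and imaginary parts independently, and a tensor that is not lazily conjugated.

```python
import torch
import torch.distributed as dist


def all_reduce_complex(tensor, reduce_fn=None):
    """Sum-reduce `tensor` in place; complex tensors go through a real view.

    Hypothetical helper. `reduce_fn` must reduce its argument in place and
    defaults to torch.distributed.all_reduce (SUM).
    """
    if reduce_fn is None:
        reduce_fn = dist.all_reduce  # needs an initialized process group
    if not tensor.is_complex():
        reduce_fn(tensor)
        return tensor
    # Complex shape (...) becomes real shape (..., 2), sharing the same memory.
    real_view = torch.view_as_real(tensor)
    reduce_fn(real_view)
    # The in-place reduction on the view is visible through `tensor`.
    return tensor
```

With an NCCL process group initialized, calling `all_reduce_complex(z)` on a CUDA `cfloat` tensor would hand NCCL a `float32` tensor with a trailing dimension of size 2. A product reduction would not be correct through this view, since an elementwise product of the parts is not a complex product.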

ezyang avatar Feb 10 '22 03:02 ezyang

Removed "triaged" to discuss in distributed oncall

rohan-varma avatar Feb 10 '22 19:02 rohan-varma

@ezyang Sorry, I assume you mean use torch.view_as_real, but I'm unsure how to modify the above DDP example to use it, or do you mean for a custom distributed setup?

rmchurch avatar Feb 23 '22 16:02 rmchurch

> oops, we should definitely fix this. As a workaround, view the complex to real and send that to nccl.

if the tensor is conjugated, you'd also need to materialize the conjugation before you can view it as real.
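A small sketch of what that means in practice, using only a toy tensor (the variable names are illustrative):

```python
import torch

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.cfloat)
zc = z.conj()  # lazy: sets a conjugate bit, no data is copied
print(zc.is_conj())

try:
    torch.view_as_real(zc)  # rejected while the conjugation is unresolved
except RuntimeError as err:
    print("view_as_real failed:", type(err).__name__)

resolved = zc.resolve_conj()  # materializes the conjugated values
real_view = torch.view_as_real(resolved)
print(real_view)
```

Note that `resolved` no longer aliases `z`, so anything written into `real_view` (for example by an in-place collective) would have to be copied back if the original tensor is the one that should be updated.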

anjali411 avatar Feb 23 '22 18:02 anjali411

> @ezyang Sorry, I assume you mean use torch.view_as_real, but I'm unsure how to modify the above DDP example to use it, or do you mean for a custom distributed setup?

I guess we need to fix this inside DDP packages, will work on this fix

zhaojuanmao avatar Feb 28 '22 20:02 zhaojuanmao

Has this issue been fixed yet?

Remponator avatar Jun 21 '22 12:06 Remponator

cc. @kumpera who was working on #74039

anjali411 avatar Jun 22 '22 20:06 anjali411

Has this issue been fixed yet?

z870609382 avatar Sep 14 '22 12:09 z870609382

Has this issue been fixed yet?

Ayiing avatar Sep 15 '22 09:09 Ayiing

Is there any update on this? I'm running into this issue currently.

aclifton314 avatar Oct 20 '22 16:10 aclifton314

> @ezyang Sorry, I assume you mean use torch.view_as_real, but I'm unsure how to modify the above DDP example to use it, or do you mean for a custom distributed setup?

Maybe the label 'has workaround' has to be removed, except if somebody can show a workaround for the given example. torch.view_as_real might not be easily applicable to the weights of the model, which would be required for the example.

NikolasMorshuis avatar Oct 27 '22 10:10 NikolasMorshuis

cc. @kumpera @mrshenli are we still planning to push this PR https://github.com/pytorch/pytorch/pull/74039 through?

anjali411 avatar Oct 27 '22 13:10 anjali411

I had exactly the same problem when I tried to do distributed training on the cluster. Just want to check if there is a timeline for the solution to be pushed through.

qingkaikong avatar Nov 01 '22 17:11 qingkaikong

Also running into this problem.

Dahoas avatar Nov 16 '22 21:11 Dahoas

@Dahoas, I found that the workaround given above works well for now, until the new changes are merged into master. Try it: do the complex-number operations manually on the real and imaginary parts, then use view_as_complex afterwards.
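To make "manually on the real and imaginary parts" concrete, here is a rough sketch for a linear layer without bias. The function name `complex_linear_from_parts` is made up for illustration; it relies only on (a + ib)(c + id) = (ac - bd) + i(ad + bc).

```python
import torch


def complex_linear_from_parts(x_re, x_im, w_re, w_im):
    """Compute (x_re + i*x_im) @ (w_re + i*w_im).T using real tensors only."""
    out_re = x_re @ w_re.T - x_im @ w_im.T
    out_im = x_re @ w_im.T + x_im @ w_re.T
    return out_re, out_im


torch.manual_seed(0)
x = torch.randn(4, 3, dtype=torch.cfloat)
w = torch.randn(5, 3, dtype=torch.cfloat)

out_re, out_im = complex_linear_from_parts(x.real, x.imag, w.real, w.imag)
# Stack the parts on a new last dimension of size 2, then reinterpret as complex.
out = torch.view_as_complex(torch.stack((out_re, out_im), dim=-1))
```

In a model, `w_re` and `w_im` would be ordinary real `nn.Parameter`s, so nothing complex ever reaches the NCCL backend.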

qingkaikong avatar Nov 17 '22 18:11 qingkaikong

Hi, any update on this issue?

heibaidaolx123 avatar Feb 07 '23 08:02 heibaidaolx123

Hi, what's the status on this?

lucidrains avatar Feb 07 '23 16:02 lucidrains

Can someone also detail the workaround if there are too many blockers for this issue to be completed in a timely manner? Thanks

lucidrains avatar Feb 07 '23 16:02 lucidrains

Is https://github.com/pytorch/pytorch/issues/71613#issuecomment-1034448731 enough detail?

ezyang avatar Feb 07 '23 19:02 ezyang

Not for me, I'm slow 😆 but if that's all you will let on, fine I'll figure it out

@ezyang is this still in progress?

lucidrains avatar Feb 07 '23 19:02 lucidrains

> @ezyang Sorry, I assume you mean use torch.view_as_real, but I'm unsure how to modify the above DDP example to use it, or do you mean for a custom distributed setup?

> Maybe the label 'has workaround' has to be removed, except if somebody can show a workaround for the given example. torch.view_as_real might not be easily applicable to the weights of the model, which would be required for the example.

alright, i'll be back, and provide a diff for the original issue once / if i figure it out

lucidrains avatar Feb 07 '23 19:02 lucidrains

@ezyang ok, i've figured it out, but the solution is not a solution at all. it means one cannot use any network with parameters of complex float, as noted by the comment above. judging by your curt response, i'm going to guess the answer to my original question is "no" so i will not press further. i'll find another way to train a discriminator i need to automate away music generation

lucidrains avatar Feb 07 '23 21:02 lucidrains

I think this gives you what you want:

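import torch
import torch.nn as nn

class complexModel(nn.Module):
  def __init__(self, rank):
    super(complexModel, self).__init__()
    weights1 = torch.rand(10,10,12,12, dtype=torch.cfloat).to(rank)
    # store the complex weights as a real tensor with a trailing dimension of 2
    self.weights1 = nn.Parameter(torch.view_as_real(weights1))
  def forward(self,x):
    # reinterpret the real parameter as complex before using it
    weights1 = torch.view_as_complex(self.weights1)
    return weights1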

rmchurch avatar Feb 07 '23 21:02 rmchurch

@rmchurch yes, i'm doing something similar here. at this point, i'll just venture on the path of keeping everything as real and doing all the complex stuff manually

lucidrains avatar Feb 07 '23 21:02 lucidrains

@rmchurch ahh ok, what i was missing is that the pytorch functions can still manage complex representations. thank you for your example

the final running workaround for the issue is:

import torch
import os
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

import torch.nn as nn
import torch.nn.functional as F

class ComplexLinear(nn.Module):
    def __init__(
        self,
        dim,
        dim_out
    ):
        super().__init__()
        linear = nn.Linear(dim, dim_out, dtype = torch.cfloat)
        self.weight = nn.Parameter(torch.view_as_real(linear.weight))
        self.bias = nn.Parameter(torch.view_as_real(linear.bias))

    def forward(self, x):
        weight = torch.view_as_complex(self.weight)
        bias = torch.view_as_complex(self.bias)
        return F.linear(x, weight, bias)

class ModReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x))

def example(rank, world_size):
    # create default process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(
        ComplexLinear(10, 10),
        ModReLU()
    ).to(rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    x = torch.view_as_complex(torch.rand(20, 10, 2)).to(rank)

    outputs = ddp_model(x)
    labels = torch.randn(20, 10).to(rank)

    # backward pass
    loss_fn(torch.abs(outputs), labels).backward()

    # update parameters
    optimizer.step()
    print(outputs)

def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    os.environ['MASTER_ADDR'] =  'localhost' 
    os.environ['MASTER_PORT'] = str(np.random.randint(10000, 20000))
    main()

my use case, in case it helps anyone

class ComplexConv2d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        kernel_size,
        stride = 1,
        padding = 0
    ):
        super().__init__()
        conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)
        self.weight = nn.Parameter(torch.view_as_real(conv.weight))
        self.bias = nn.Parameter(torch.view_as_real(conv.bias))

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        weight, bias = map(torch.view_as_complex, (self.weight, self.bias))
        return F.conv2d(x, weight, bias, stride = self.stride, padding = self.padding)

edit: this will only work for pytorch 1.12+
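One way to check that a model is ready for this workaround is to list whatever complex tensors it still registers before wrapping it in DDP. The sketch below is an assumption-laden illustration: `complex_tensor_names` and the toy `RealViewParam` module are hypothetical names, not part of PyTorch.

```python
import torch
import torch.nn as nn


class RealViewParam(nn.Module):
    """Toy module (hypothetical) that stores a complex weight as a real view."""

    def __init__(self):
        super().__init__()
        w = torch.randn(4, 4, dtype=torch.cfloat)
        # float32 parameter of shape (4, 4, 2)
        self.weight = nn.Parameter(torch.view_as_real(w))

    def forward(self, x):
        return x @ torch.view_as_complex(self.weight)


def complex_tensor_names(module):
    """Return names of parameters and buffers that still have a complex dtype."""
    names = [n for n, p in module.named_parameters() if p.is_complex()]
    names += [n for n, b in module.named_buffers() if b.is_complex()]
    return names
```

An empty list from `complex_tensor_names(model)` suggests every registered tensor is real; any name it returns is a candidate for the `view_as_real` / `view_as_complex` treatment shown above.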

lucidrains avatar Feb 08 '23 03:02 lucidrains

Hi, is this still going to be fixed? The workaround proposed above is not always feasible and is often error-prone. It would be much cleaner if this were fixed within the DDP package. Thanks.

wcqc avatar May 18 '23 15:05 wcqc

Ke tells me we won't fix this in NCCL directly; instead, we need to handle converting complex float to just float in userland
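For a custom training loop that averages gradients by hand instead of using DDP, "converting in userland" could look something like the following sketch. `average_gradients` and `reduce_fn` are hypothetical names; `reduce_fn` defaults to `torch.distributed.all_reduce` and is a parameter only so the logic can be exercised without a process group. This is not what DDP does internally, and it gives up DDP's bucketing and communication overlap.

```python
import torch
import torch.distributed as dist


def average_gradients(params, world_size, reduce_fn=None):
    """Average gradients across ranks, sending complex grads as real views.

    Hypothetical helper; `reduce_fn` must sum its argument across ranks in place.
    """
    if reduce_fn is None:
        reduce_fn = dist.all_reduce  # needs an initialized process group
    for p in params:
        if p.grad is None:
            continue  # parameter did not take part in the backward pass
        grad = p.grad.resolve_conj()  # returns the same tensor if not conjugated
        buf = torch.view_as_real(grad) if grad.is_complex() else grad
        reduce_fn(buf)         # in-place sum over ranks on a real tensor
        buf.div_(world_size)   # turn the sum into a mean
        p.grad = grad
```

It would be called between `loss.backward()` and `optimizer.step()`, for example `average_gradients(model.parameters(), dist.get_world_size())`.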

ezyang avatar Jun 06 '23 03:06 ezyang

Has there been any progress on this? I'm also in need of DDP for complex-valued NNs, and I am wondering if there are plans for this issue to be resolved.

arthurmccray avatar Mar 14 '24 03:03 arthurmccray

Just wondering, has any fix been done?

alundilong avatar Apr 16 '24 21:04 alundilong

is there any progress on this?

JALB-epsilon avatar Jun 29 '24 17:06 JALB-epsilon