complexPyTorch
ComplexDropout2d Device Error
Hi, thank you for the nice library.
There seems to be a small mistake in the `complexPyTorch.complexLayers.ComplexDropout2d` layer, which raises a device-mismatch error (torch version 2.0.1+cu118):
```
.... line 106, in complex_dropout
    return mask*input
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
I managed to solve it by placing the mask on the same device as the input in `complexPyTorch.complexFunctions.complex_dropout2d`, as follows:
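```python
import torch
import torch.nn.functional as F


def complex_dropout2d(input, p=0.5, training=True):
    # Need to have the same dropout mask for the real and imaginary parts;
    # this is not a clean solution!
    # ones_like creates the mask on input.device, which fixes the cuda/cpu mismatch.
    mask = torch.ones_like(input, dtype=torch.float32)
    # dropout2d already rescales the kept channels by 1/(1-p) in training mode
    # and returns the mask unchanged in eval mode, so no extra 1/(1-p) factor.
    mask = F.dropout2d(mask, p, training)
    # A float32 mask times a complex input promotes to the complex dtype,
    # so no explicit cast is needed.
    return mask * input
```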
Best!
The same applies to all the dropout functions. Any updates on an official fix?
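Since the same device problem shows up in every dropout variant, here is a minimal sketch of one way to patch them all at once, assuming you maintain your own copy of `complexFunctions.py`. The helper name `complex_dropout_any` and its `dropout_fn` parameter are hypothetical (not part of complexPyTorch): it builds one real-valued mask on the input's device, runs it through any real dropout function from `torch.nn.functional`, and multiplies it into the complex input so the real and imaginary parts are dropped together.

```python
import torch
import torch.nn.functional as F


def complex_dropout_any(input, p=0.5, training=True, dropout_fn=F.dropout):
    # Hypothetical helper, not part of complexPyTorch.
    # One real-valued mask is shared by the real and imaginary parts.
    # ones_like creates the mask on input.device, avoiding the cuda/cpu mismatch.
    mask = torch.ones_like(input, dtype=torch.float32)
    # The real dropout functions (F.dropout, F.dropout1d, F.dropout2d, F.dropout3d)
    # rescale kept entries by 1/(1-p) when training and return the mask unchanged otherwise.
    mask = dropout_fn(mask, p, training)
    return mask * input


# Usage: runs on CUDA when available, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 4, 5, 5, dtype=torch.complex64, device=device)

y_elem = complex_dropout_any(x, 0.5, True, F.dropout)     # element-wise mask
y_chan = complex_dropout_any(x, 0.5, True, F.dropout2d)   # whole-channel mask
y_eval = complex_dropout_any(x, 0.5, False, F.dropout2d)  # eval mode: identity

print(y_elem.device, y_chan.dtype, bool((y_eval == x).all()))
```

Passing the dropout function as an argument keeps a single code path for the device fix; each variant still expects its usual input rank (for example, 3D or 4D tensors for `F.dropout2d`).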
Hi, I can look at it when I have some time, but I am no longer working on this code: I consider it obsolete and no longer need it, now that current versions of PyTorch implement complex tensors natively. Feel free to fork it, and if you need these changes, please open a pull request; I will review it. Best,
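With native complex tensors (e.g. `torch.complex64`), a channel-dropout layer can be a thin `nn.Module`. A minimal sketch, assuming inputs are native complex tensors of shape (N, C, H, W); the class name `NativeComplexDropout2d` is hypothetical and belongs to neither complexPyTorch nor PyTorch. Because the mask is created from the input on every call and the layer reads `self.training`, it follows the input's device and responds to `.eval()` without extra bookkeeping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NativeComplexDropout2d(nn.Module):
    """Hypothetical channel dropout for native complex tensors of shape (N, C, H, W)."""

    def __init__(self, p=0.5):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"dropout probability must be in [0, 1], got {p}")
        self.p = p

    def forward(self, x):
        if not x.is_complex():
            raise TypeError("expected a complex tensor")
        # Real mask (float32 for complex64, float64 for complex128) on x.device,
        # shared by the real and imaginary parts.
        mask = F.dropout2d(torch.ones_like(x, dtype=x.real.dtype), self.p, self.training)
        return mask * x


# Usage: runs on CUDA when available, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = NativeComplexDropout2d(p=0.3)
x = torch.randn(2, 3, 4, 4, dtype=torch.complex64, device=device)

y_train = layer(x)  # training mode (default): whole channels are zeroed
layer.eval()
y_eval = layer(x)   # eval mode: returns x unchanged
```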