complexPyTorch
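Complex convolution: conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0)), where the weight W, input x and bias b are all complex (subscripts r and i denote real and imaginary parts). With the Gauss trick, only three real convolutions are needed instead of four: a = conv(Wr, xr, br), b = conv(Wi, xi, 0), c = conv(Wr + Wi, xr + xi, br + bi), and then conv(W, x, b) = a - b + i(c - a - b). Note that the intermediate b here is a separate quantity from the bias b.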
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class ComplexConvTranspose1dn(nn.ConvTranspose1d):
    """1-D transposed convolution over complex tensors using the Gauss trick.

    The complex product is evaluated with three real ``conv_transpose1d``
    calls instead of four:

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    ``weight`` (and ``bias``, when present) are read through ``.real`` /
    ``.imag``, so they are expected to be complex tensors (e.g. construct the
    module with ``dtype=torch.cfloat``).
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: input tensor; a real-valued tensor is treated as having a
                zero imaginary part.
            output_size: optional explicit output size, forwarded to
                ``_output_padding`` exactly as ``nn.ConvTranspose1d`` does.

        Returns:
            Complex tensor assembled with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        x_r = input.real
        # `.imag` raises on real tensors; a real input simply has zero imaginary part.
        x_i = input.imag if input.is_complex() else torch.zeros_like(input)
        w_r = self.weight.real
        w_i = self.weight.imag

        # bias=False leaves self.bias as None; pass None through instead of
        # dereferencing it. Bind to a local so the None-check narrows the type.
        bias = self.bias
        b_r: Optional[Tensor] = None
        b_sum: Optional[Tensor] = None
        if bias is not None:
            b_r = bias.real
            b_sum = bias.real + bias.imag

        a = F.conv_transpose1d(x_r, w_r, b_r, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        b = F.conv_transpose1d(x_i, w_i, None, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        c = F.conv_transpose1d(x_r + x_i, w_r + w_i, b_sum, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        # Real part: a - b; imaginary part: c - a - b (cross terms plus bi).
        return torch.complex(a - b, c - a - b)
class ComplexConvTranspose2dn(nn.ConvTranspose2d):
    """2-D transposed convolution over complex tensors using the Gauss trick.

    The complex product is evaluated with three real ``conv_transpose2d``
    calls instead of four:

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    ``weight`` (and ``bias``, when present) are read through ``.real`` /
    ``.imag``, so they are expected to be complex tensors (e.g. construct the
    module with ``dtype=torch.cfloat``).
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: input tensor; a real-valued tensor is treated as having a
                zero imaginary part.
            output_size: optional explicit output size, forwarded to
                ``_output_padding`` exactly as ``nn.ConvTranspose2d`` does.

        Returns:
            Complex tensor assembled with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        x_r = input.real
        # `.imag` raises on real tensors; a real input simply has zero imaginary part.
        x_i = input.imag if input.is_complex() else torch.zeros_like(input)
        w_r = self.weight.real
        w_i = self.weight.imag

        # bias=False leaves self.bias as None; pass None through instead of
        # dereferencing it. Bind to a local so the None-check narrows the type.
        bias = self.bias
        b_r: Optional[Tensor] = None
        b_sum: Optional[Tensor] = None
        if bias is not None:
            b_r = bias.real
            b_sum = bias.real + bias.imag

        a = F.conv_transpose2d(x_r, w_r, b_r, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        b = F.conv_transpose2d(x_i, w_i, None, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        c = F.conv_transpose2d(x_r + x_i, w_r + w_i, b_sum, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        # Real part: a - b; imaginary part: c - a - b (cross terms plus bi).
        return torch.complex(a - b, c - a - b)
class ComplexConvTranspose3dn(nn.ConvTranspose3d):
    """3-D transposed convolution over complex tensors using the Gauss trick.

    The complex product is evaluated with three real ``conv_transpose3d``
    calls instead of four:

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    ``weight`` (and ``bias``, when present) are read through ``.real`` /
    ``.imag``, so they are expected to be complex tensors (e.g. construct the
    module with ``dtype=torch.cfloat``).
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: input tensor; a real-valued tensor is treated as having a
                zero imaginary part.
            output_size: optional explicit output size, forwarded to
                ``_output_padding`` exactly as ``nn.ConvTranspose3d`` does.

        Returns:
            Complex tensor assembled with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        x_r = input.real
        # `.imag` raises on real tensors; a real input simply has zero imaginary part.
        x_i = input.imag if input.is_complex() else torch.zeros_like(input)
        w_r = self.weight.real
        w_i = self.weight.imag

        # bias=False leaves self.bias as None; pass None through instead of
        # dereferencing it. Bind to a local so the None-check narrows the type.
        bias = self.bias
        b_r: Optional[Tensor] = None
        b_sum: Optional[Tensor] = None
        if bias is not None:
            b_r = bias.real
            b_sum = bias.real + bias.imag

        a = F.conv_transpose3d(x_r, w_r, b_r, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        b = F.conv_transpose3d(x_i, w_i, None, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        c = F.conv_transpose3d(x_r + x_i, w_r + w_i, b_sum, self.stride, self.padding,
                               output_padding, self.groups, self.dilation)
        # Real part: a - b; imaginary part: c - a - b (cross terms plus bi).
        return torch.complex(a - b, c - a - b)