pytorch-complex-tensor
Multiplication of complex tensors
Hi! I was trying to perform some operations with your package, but I could not get element-wise multiplication to work.
Consider the following code:
import torch
from pytorch_complex_tensor import ComplexTensor
A = torch.randn(10, 4, 5, 3)
FA = torch.rfft(A, 2, onesided=False) # get complex tensor from 2D rFFT
FA.shape # torch.Size([10, 4, 5, 3, 2])
C = ComplexTensor(FA.transpose(4,3)) # transpose the last two dimensions to match the input format
C.shape # torch.Size([10, 4, 5, 1, 3])
CC = C * C
CC.shape # torch.Size([20, 4, 5, 0, 3])
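The last output is wrong: the first dimension has doubled to 20 and the fourth dimension is empty (size 0), whereas I expected the same shape as C, torch.Size([10, 4, 5, 1, 3]). How is element-wise multiplication supposed to work? Thanks!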
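For comparison, element-wise complex multiplication needs only the real and imaginary parts, (a + bi)(c + di) = (ac - bd) + (ad + bc)i, and the result keeps the shape of its inputs. The sketch below is a possible workaround that bypasses ComplexTensor entirely; it is not part of pytorch_complex_tensor, and `complex_mul` is a hypothetical helper name.

```python
def complex_mul(a_re, a_im, b_re, b_im):
    """Element-wise product (a_re + i*a_im) * (b_re + i*b_im).

    Hypothetical helper, not part of pytorch_complex_tensor.
    Works on anything that supports element-wise *, + and -
    (torch tensors, NumPy arrays, or plain Python numbers), so the
    output has the same (broadcast) shape as the inputs.
    Returns the (real, imaginary) parts as a tuple.
    """
    real = a_re * b_re - a_im * b_im
    imag = a_re * b_im + a_im * b_re
    return real, imag
```

With the old torch.rfft layout, where the last dimension of size 2 holds (real, imaginary), this could be applied as `re, im = complex_mul(FA[..., 0], FA[..., 1], FA[..., 0], FA[..., 1])`, followed by `torch.stack((re, im), dim=-1)`, which gives torch.Size([10, 4, 5, 3, 2]), the same shape as FA.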