
Pytorch Complex Tensor

Unofficial complex tensor and scalar support for PyTorch.

How it works

A ComplexTensor stores the real part in the first half of the rows and the imaginary part in the second half. A small set of arithmetic operations is implemented to emulate complex arithmetic on this layout, and gradients are supported.
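To make the layout concrete, here is a minimal pure-Python sketch. It assumes a 2-D tensor given as a list of rows with an even row count. The helper `stacked_to_complex` is hypothetical and not part of the package. It only shows how row i (real part) pairs with row i + n (imaginary part).

```python
def stacked_to_complex(rows):
    """Pair row i (real part) with row i + n (imaginary part), n = len(rows) // 2.

    Hypothetical helper for illustration; not part of pytorch_complex_tensor.
    """
    if len(rows) % 2 != 0:
        raise ValueError("stacked layout needs an even number of rows")
    n = len(rows) // 2
    real, imag = rows[:n], rows[n:]
    return [
        [complex(re, im) for re, im in zip(re_row, im_row)]
        for re_row, im_row in zip(real, imag)
    ]


# Same data as the "Init tensor" example below
stacked = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print(stacked_to_complex(stacked))
# [[(1+3j), (1+3j), (1+3j)], [(2+4j), (2+4j), (2+4j)]]
```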

Install

pip install pytorch-complex-tensor

Easy import

from pytorch_complex_tensor import ComplexTensor

Init tensor

# equivalent to:
# np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
C.requires_grad = True

Pretty printing

print(C)
# tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],
#         ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])

Handles the absolute value (modulus) of complex entries correctly:

# complex absolute value: sqrt(real**2 + imag**2)
print(C.abs())
# tensor([[3.1623, 3.1623, 3.1623],
#         [4.4721, 4.4721, 4.4721]], grad_fn=<SqrtBackward>)
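The values above follow from the modulus formula applied to each real/imaginary pair, e.g. sqrt(1² + 3²) = sqrt(10) ≈ 3.1623. A small sketch on the stacked layout could look like the following. It uses `math.hypot`, which computes sqrt(re² + im²). The helper name `complex_abs` is hypothetical, and the package may compute the square root directly, as its `SqrtBackward` grad_fn suggests.

```python
import math


def complex_abs(rows):
    """Modulus sqrt(re**2 + im**2) of each entry of a stacked real/imag matrix.

    Hypothetical helper for illustration only.
    """
    n = len(rows) // 2
    return [
        [math.hypot(re, im) for re, im in zip(re_row, im_row)]
        for re_row, im_row in zip(rows[:n], rows[n:])
    ]


stacked = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print([[round(v, 4) for v in row] for row in complex_abs(stacked)])
# [[3.1623, 3.1623, 3.1623], [4.4721, 4.4721, 4.4721]]
```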

Reports the complex size, treating the first half of the rows as real and the second half as imaginary:

print(C.size())
# torch.Size([2, 3])
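The sizing rule amounts to halving the first dimension of the underlying storage, so the 4×3 tensor passed to the constructor reports a 2×3 complex size. A sketch of that rule follows. The function `complex_shape` is hypothetical and assumes the real/imaginary split is always along the first dimension.

```python
def complex_shape(stacked_shape):
    """Logical complex shape of a stacked (2n, ...) real/imag tensor.

    Hypothetical helper for illustration only.
    """
    rows, *rest = stacked_shape
    if rows % 2 != 0:
        raise ValueError("first dimension must be even")
    return (rows // 2, *rest)


print(complex_shape((4, 3)))
# (2, 3)
```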

Supports matrix multiplication with both complex and real tensors:

import torch

# matrix multiply with a real tensor
# (also works with a complex tensor)
x = torch.Tensor([[3, 3], [4, 4], [2, 2]])
xy = C.mm(x)
print(xy)
# tensor([['(9.0+27.0j)' '(9.0+27.0j)'],
#         ['(18.0+36.0j)' '(18.0+36.0j)']])
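Matrix multiplication on the stacked layout reduces to real products of the halves. With C = A + iB, a real right operand gives (A + iB)X = AX + iBX, and a complex one gives (A + iB)(X + iY) = (AX - BY) + i(AY + BX). The pure-Python sketch below applies those identities to lists of rows. `complex_mm` and its `other_is_complex` flag are hypothetical and are not the package's `mm` signature.

```python
def matmul(a, b):
    """Plain real matrix product of two lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def split(stacked):
    """Split a stacked matrix into (real rows, imaginary rows)."""
    n = len(stacked) // 2
    return stacked[:n], stacked[n:]


def add(a, b):
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def sub(a, b):
    return [[p - q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def complex_mm(c, other, other_is_complex=False):
    """Stacked complex matrix times a real or stacked complex matrix.

    Hypothetical helper; the result is returned in the same stacked layout
    (list concatenation puts the real rows above the imaginary rows).
    """
    a, b = split(c)
    if not other_is_complex:
        # (A + iB) X = AX + i BX
        return matmul(a, other) + matmul(b, other)
    x, y = split(other)
    # (A + iB)(X + iY) = (AX - BY) + i (AY + BX)
    return sub(matmul(a, x), matmul(b, y)) + add(matmul(a, y), matmul(b, x))


c = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
x = [[3, 3], [4, 4], [2, 2]]
print(complex_mm(c, x))
# [[9, 9], [18, 18], [27, 27], [36, 36]]
```

The first two stacked rows are the real parts (9, 18) and the last two are the imaginary parts (27, 36), matching the printed result above.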

Reduction ops return a ComplexScalar:

xy = xy.sum()

# xy is now a complex scalar (a thin wrapper exposing .real and .imag)
print(type(xy))
# <class 'pytorch_complex_tensor.complex_scalar.ComplexScalar'>

print(xy)
# (54+126j)
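A sum over a complex tensor is the sum of the real half plus i times the sum of the imaginary half: 9 + 9 + 18 + 18 = 54 and 27 + 27 + 36 + 36 = 126. The sketch below reproduces that with a plain Python `complex`. `complex_sum` is a hypothetical helper, not the package's `ComplexScalar` type.

```python
def complex_sum(stacked):
    """Reduce a stacked real/imag matrix to one Python complex number.

    Hypothetical helper for illustration only.
    """
    n = len(stacked) // 2
    real = sum(sum(row) for row in stacked[:n])
    imag = sum(sum(row) for row in stacked[n:])
    return complex(real, imag)


# Stacked form of xy from the matrix-multiply example above
xy = [[9, 9], [18, 18], [27, 27], [36, 36]]
print(complex_sum(xy))
# (54+126j)
```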

A ComplexScalar can be backpropagated through like any scalar loss; the gradient is taken with respect to its real part:

# calculate dxy / dC
# for complex scalars, the gradient is taken wrt the real part
xy.backward()
print(C.grad)
# tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'],
#         ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']])
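The gradient values can be checked by hand. With a real x, Re(sum(C x)) = Σ_ij Re(C)_ij Σ_k x_jk, so its derivative with respect to Re(C)_ij is the j-th row sum of x (3 + 3 = 6, 4 + 4 = 8, 2 + 2 = 4). The imaginary part of C does not enter, which matches the zero imaginary gradient. The sketch below confirms this with a forward difference in pure Python. The function names are hypothetical.

```python
def real_part_of_sum(real, imag, x):
    """f = Re(sum(C @ x)) for C = real + i*imag and a real matrix x."""
    # With a real x, the real part of C @ x is real @ x, so imag drops out.
    return sum(
        sum(r * v for r, v in zip(row, col))
        for row in real
        for col in zip(*x)
    )


def numerical_grad(real, imag, x, h=1.0):
    """Forward-difference gradient of f with respect to each entry of real.

    f is linear in real, so any step size h gives the exact derivative.
    """
    base = real_part_of_sum(real, imag, x)
    grad = []
    for i, row in enumerate(real):
        grad_row = []
        for j in range(len(row)):
            bumped = [list(r) for r in real]  # copy, then perturb one entry
            bumped[i][j] += h
            grad_row.append((real_part_of_sum(bumped, imag, x) - base) / h)
        grad.append(grad_row)
    return grad


real = [[1, 1, 1], [2, 2, 2]]
imag = [[3, 3, 3], [4, 4, 4]]
x = [[3, 3], [4, 4], [2, 2]]
print(numerical_grad(real, imag, x))
# [[6.0, 8.0, 4.0], [6.0, 8.0, 4.0]]
```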

Supports indexing and slicing, for example:

print(C[0, 0:-2, ...])
print(C[0, ..., 0])
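On the stacked layout, indexing the complex view means applying the same selection to row `row` of the real half and row `row + n` of the imaginary half, then pairing the results. The sketch below covers only a single row index plus an integer or slice for the columns. `complex_getitem` is hypothetical and does not model the package's full `...` (Ellipsis) handling.

```python
def complex_getitem(stacked, row, cols):
    """Index the logical complex matrix at (row, cols) on a stacked layout.

    Hypothetical helper for illustration only: the same column selection is
    applied to the real row and to the matching imaginary row.
    """
    n = len(stacked) // 2
    if not -n <= row < n:
        raise IndexError("row index out of range")
    row %= n  # map negative indices into the real half
    re = stacked[row][cols]
    im = stacked[row + n][cols]
    if isinstance(cols, slice):
        return [complex(r, i) for r, i in zip(re, im)]
    return complex(re, im)


c = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print(complex_getitem(c, 0, slice(0, -2)))  # like C[0, 0:-2, ...]
# [(1+3j)]
print(complex_getitem(c, 0, 0))  # like C[0, ..., 0]
# (1+3j)
```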

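Supported ops (Y = supported, - = not applicable):

| Operation   | With complex tensor | With real tensor | With complex scalar | With real scalar |
|-------------|---------------------|------------------|---------------------|------------------|
| addition    | Y                   | Y                | Y                   | Y                |
| subtraction | Y                   | Y                | Y                   | Y                |
| multiply    | Y                   | Y                | Y                   | Y                |
| mm          | Y                   | Y                | Y                   | Y                |
| abs         | Y                   | -                | -                   | -                |
| t           | Y                   | -                | -                   | -                |
| grads       | Y                   | Y                | Y                   | Y                |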
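Unary ops such as `t` also need to respect the stacked layout. Transposing the whole (2n, m) storage would put real and imaginary entries side by side in each row. Instead, each half can be transposed separately and the results restacked. The sketch below assumes `t` is the plain transpose, not the conjugate transpose, and uses a hypothetical helper name.

```python
def transpose(m):
    """Plain transpose of a list of rows."""
    return [list(col) for col in zip(*m)]


def complex_t(stacked):
    """Transpose a stacked real/imag matrix, keeping the stacked layout.

    Hypothetical helper: transposes the real and imaginary halves on their
    own, then puts the real result above the imaginary result.
    """
    n = len(stacked) // 2
    return transpose(stacked[:n]) + transpose(stacked[n:])


# Real part [[1, 2, 3], [4, 5, 6]], imaginary part [[7, 8, 9], [10, 11, 12]]
c = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(complex_t(c))
# [[1, 4], [2, 5], [3, 6], [7, 10], [8, 11], [9, 12]]
```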