pytorch-complex-tensor
Unofficial complex tensor and scalar support for PyTorch
PyTorch Complex Tensor
Unofficial complex tensor support for PyTorch
How it works
Treats the first half of the tensor (along the first dimension) as the real part and the second half as the imaginary part. A few arithmetic operations are implemented to emulate complex arithmetic, and gradients are supported.
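The following sketch makes the layout concrete. It is written in plain Python, using lists of rows and the built-in `complex` type rather than torch. It converts between the stacked real/imaginary form and ordinary complex numbers. The helper names `stacked_to_complex` and `complex_to_stacked` are hypothetical and are not part of the library's API.

```python
# Hypothetical helpers illustrating the stacked layout; not the library's code.
# Rows [0, n/2) hold the real parts, rows [n/2, n) hold the imaginary parts.

def stacked_to_complex(rows):
    """Interpret 2n equal-length real rows as an n-row complex matrix."""
    if len(rows) % 2 != 0:
        raise ValueError("expected an even number of rows (real half + imaginary half)")
    half = len(rows) // 2
    real, imag = rows[:half], rows[half:]
    return [[complex(re, im) for re, im in zip(re_row, im_row)]
            for re_row, im_row in zip(real, imag)]


def complex_to_stacked(matrix):
    """Inverse: split an n-row complex matrix into 2n real rows."""
    return ([[z.real for z in row] for row in matrix] +
            [[z.imag for z in row] for row in matrix])


C = stacked_to_complex([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
print(C)                    # [[(1+3j), (1+3j), (1+3j)], [(2+4j), (2+4j), (2+4j)]]
print((len(C), len(C[0])))  # (2, 3): the logical complex shape
```

Four rows are stored, but the logical shape is (2, 3). This is the same halving that the `C.size()` example reports.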
Installation
pip install pytorch-complex-tensor
Example:
Easy import
from pytorch_complex_tensor import ComplexTensor
Init tensor
# equivalent to:
# np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
C.requires_grad = True
Pretty printing
print(C)
# tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],
# ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])
computes the absolute value correctly for complex tensors
# complex absolute value implementation
print(C.abs())
# tensor([[3.1623, 3.1623, 3.1623],
# [4.4721, 4.4721, 4.4721]], grad_fn=<SqrtBackward>)
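The complex absolute value is |a + bj| = sqrt(a² + b²), computed elementwise from matching positions in the two halves. The sketch below uses a hypothetical helper, `stacked_abs`, written in plain Python. It calls `math.hypot`, which returns the same value as the explicit square root. The `grad_fn=<SqrtBackward>` above suggests that the library itself takes a square root, but that detail is an inference, not documented behaviour.

```python
import math


def stacked_abs(rows):
    """Elementwise |re + i*im| for a stacked real/imaginary matrix (hypothetical helper)."""
    half = len(rows) // 2
    real, imag = rows[:half], rows[half:]
    # math.hypot(re, im) equals sqrt(re**2 + im**2) and avoids overflow for large inputs
    return [[math.hypot(re, im) for re, im in zip(re_row, im_row)]
            for re_row, im_row in zip(real, imag)]


mags = stacked_abs([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
print([[round(m, 4) for m in row] for row in mags])
# [[3.1623, 3.1623, 3.1623], [4.4721, 4.4721, 4.4721]]
```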
reports the complex size, treating the first half of the matrix as real and the second half as imaginary
print(C.size())
# torch.Size([2, 3])
matrix-multiplies with both complex and real tensors
# matrix multiply with a real tensor
# (also works with a complex tensor)
import torch
x = torch.Tensor([[3, 3], [4, 4], [2, 2]])
xy = C.mm(x)
print(xy)
# tensor([['(9.0+27.0j)' '(9.0+27.0j)'],
# ['(18.0+36.0j)' '(18.0+36.0j)']])
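Matrix multiplication in the stacked form follows from distributing over the real and imaginary parts:

- For a real X: (A + iB)X = AX + i(BX).
- For a complex operand: (A + iB)(C + iD) = (AC - BD) + i(AD + BC).

Below is a plain-Python sketch with hypothetical helper names. The library's actual implementation may differ.

```python
def matmul(a, b):
    """Real matrix product of two lists of rows."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


def combine(m1, m2, op):
    """Apply op elementwise to two equally shaped matrices."""
    return [[op(p, q) for p, q in zip(r1, r2)] for r1, r2 in zip(m1, m2)]


def split(rows):
    """Return (real_half, imag_half) of a stacked matrix."""
    half = len(rows) // 2
    return rows[:half], rows[half:]


def stacked_mm_real(c_rows, x):
    """(A + iB) @ X = A@X + i(B@X) for a real matrix X."""
    a, b = split(c_rows)
    return matmul(a, x) + matmul(b, x)  # list concatenation re-stacks the halves


def stacked_mm_complex(c_rows, d_rows):
    """(A + iB) @ (C + iD) = (A@C - B@D) + i(A@D + B@C)."""
    a, b = split(c_rows)
    c, d = split(d_rows)
    real = combine(matmul(a, c), matmul(b, d), lambda p, q: p - q)
    imag = combine(matmul(a, d), matmul(b, c), lambda p, q: p + q)
    return real + imag


C = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
x = [[3, 3], [4, 4], [2, 2]]
print(stacked_mm_real(C, x))
# [[9, 9], [18, 18], [27, 27], [36, 36]], i.e. rows of 9+27j and 18+36j
```

A complex operand with a zero imaginary half gives the same result as the real-operand path. This is a quick consistency check between the two formulas.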
reduction ops return a ComplexScalar
xy = xy.sum()
# this is now a complex scalar (thin wrapper with .real, .imag)
print(type(xy))
# <class 'pytorch_complex_tensor.complex_scalar.ComplexScalar'>
print(xy)
# (54+126j)
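A full reduction sums each half separately. In the sketch below, Python's built-in `complex` stands in for the library's `ComplexScalar`; it also exposes `.real` and `.imag`. The helper `stacked_sum` is hypothetical.

```python
def stacked_sum(rows):
    """Sum every element of a stacked matrix into one complex number (hypothetical helper)."""
    half = len(rows) // 2
    real = sum(sum(row) for row in rows[:half])
    imag = sum(sum(row) for row in rows[half:])
    return complex(real, imag)


xy = [[9, 9], [18, 18], [27, 27], [36, 36]]  # stacked form of C.mm(x) from the example
total = stacked_sum(xy)
print(total, total.real, total.imag)
# (54+126j) 54.0 126.0
```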
which can be backpropagated through; backward() differentiates with respect to the real part
# calculate dxy / dC
# for complex scalars, grad is wrt the real part
xy.backward()
print(C.grad)
# tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'],
# ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']])
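The gradient values can be checked by hand. Write C = A + iB with real A and B. Following the README, the scalar being differentiated is the real part of the summed product:

```latex
L = \operatorname{Re}\Big(\sum_{i,j} (Cx)_{ij}\Big) = \sum_{i,j,k} A_{ik}\, x_{kj},
\qquad
\frac{\partial L}{\partial A_{ik}} = \sum_{j} x_{kj},
\qquad
\frac{\partial L}{\partial B_{ik}} = 0 .
```

With x = [[3, 3], [4, 4], [2, 2]], the row sums are 6, 8 and 4. Every row of the real-part gradient is therefore [6, 8, 4], and the imaginary part is zero, which is printed as -0.0 above.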
supports indexing and slicing (integer, slice and Ellipsis indices):
print(C[-1])
print(C[0, 0:-2, ...])
print(C[0, ..., 0])
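The real and imaginary parts live in separate halves, so a logical row index has to be applied to each half. Indexing the raw storage directly would return only one part. For example, raw row -1 of the 4x3 storage is [4, 4, 4], which is the imaginary part alone. Below is a hypothetical plain-Python helper that handles the first dimension.

```python
def stacked_rows(rows, index):
    """Apply a logical row index or slice to both halves of a stacked matrix.

    Returns the result in stacked form (real rows first, then imaginary rows).
    Hypothetical helper for illustration only.
    """
    half = len(rows) // 2
    real, imag = rows[:half], rows[half:]
    if isinstance(index, slice):
        return real[index] + imag[index]
    # A single integer selects one complex row: its real row, then its imaginary row.
    return [real[index], imag[index]]


S = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print(stacked_rows(S, -1))           # [[2, 2, 2], [4, 4, 4]], the row 2+4j
print(stacked_rows(S, slice(0, 1)))  # [[1, 1, 1], [3, 3, 3]], the row 1+3j
```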
Supported ops:
Operation | complex tensor | real tensor | complex scalar | real scalar |
---|---|---|---|---|
addition | Y | Y | Y | Y |
subtraction | Y | Y | Y | Y |
multiply | Y | Y | Y | Y |
mm | Y | Y | Y | Y |
abs | Y | - | - | - |
t | Y | - | - | - |
grads | Y | Y | Y | Y |
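For scalar operands, the arithmetic reduces to per-half updates:

- Adding p + qi adds p to the real half and q to the imaginary half.
- Multiplying uses (a + bi)(p + qi) = (ap - bq) + (aq + bp)i.
- A real scalar is the case q = 0.

The sketch is plain Python with hypothetical helper names, not the library's implementation.

```python
def stacked_add_scalar(rows, s):
    """Add a real or complex scalar to a stacked matrix (hypothetical helper)."""
    s = complex(s)  # a real scalar becomes p + 0j
    half = len(rows) // 2
    return ([[v + s.real for v in row] for row in rows[:half]] +
            [[v + s.imag for v in row] for row in rows[half:]])


def stacked_mul_scalar(rows, s):
    """Multiply a stacked matrix by a real or complex scalar (hypothetical helper)."""
    s = complex(s)
    half = len(rows) // 2
    real, imag = rows[:half], rows[half:]
    # (a + bi)(p + qi) = (ap - bq) + (aq + bp)i, applied elementwise
    new_real = [[a * s.real - b * s.imag for a, b in zip(ra, rb)] for ra, rb in zip(real, imag)]
    new_imag = [[a * s.imag + b * s.real for a, b in zip(ra, rb)] for ra, rb in zip(real, imag)]
    return new_real + new_imag


S = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]  # rows of 1+3j and 2+4j
print(stacked_add_scalar(S, 1j)[2])  # [4.0, 4.0, 4.0]: only the imaginary half changes
print(stacked_mul_scalar(S, 2j)[0])  # [-6.0, -6.0, -6.0]: real part of (1+3j) * 2j
```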