torchdiffeq
torchdiffeq copied to clipboard
real is not implemented for tensors with non-complex dtypes
Hi, the dtype of the input is float32. there is no real of the input.
the problem raises here "t = t.real.to(y.abs().dtype)"
How to solve this problem?

这个的主要原因是因为t应该是是一个复数张量(最少二维),但是在神经常微分方程,t一般是个实数张量(一维),所以只需要把这个t变成二维复数张量即可,虚部当然是一个0了。
if t.numel() == 1:
mid = [t.item(),0.0]
t = torch.tensor(mid)
t = torch.view_as_complex(t)
t = t.real.to(y.abs().dtype)
You should upgrade your PyTorch. The more recent versions (I think versions >= 1.6?) have .real implemented for non-complex tensor types. In the meantime, I'll work on a fix for the older PyTorch versions.