torchdiffeq icon indicating copy to clipboard operation
torchdiffeq copied to clipboard

real is not implemented for tensors with non-complex dtypes

Open PengleiGao opened this issue 2 years ago • 2 comments

Hi, the dtype of the input is float32. there is no real of the input. the problem raises here "t = t.real.to(y.abs().dtype)" How to solve this problem? 1678256119272

PengleiGao avatar Mar 08 '23 06:03 PengleiGao

这个的主要原因是因为t应该是是一个复数张量(最少二维),但是在神经常微分方程,t一般是个实数张量(一维),所以只需要把这个t变成二维复数张量即可,虚部当然是一个0了。

 if t.numel() == 1:
     mid = [t.item(),0.0]
     t = torch.tensor(mid)
 t = torch.view_as_complex(t)
 t = t.real.to(y.abs().dtype)

DrKarlWu avatar Apr 02 '23 16:04 DrKarlWu

You should upgrade your PyTorch. The more recent versions (I think versions >= 1.6?) have .real implemented for non-complex tensor types. In the meantime, I'll work on a fix for the older PyTorch versions.

rtqichen avatar Apr 06 '23 15:04 rtqichen