complexPyTorch
Complex MSELoss
Similar to torch.nn.MSELoss().
I think the function is fairly straightforward, following the discussion in https://github.com/pytorch/pytorch/issues/46642:
import torch

def complex_mse_loss(output, target):
    return (0.5 * (output - target) ** 2).mean(dtype=torch.complex64)
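
A rough usage sketch, assuming `output` and `target` are `torch.complex64` tensors of the same shape. Since `(output - target)**2` is the complex square rather than the squared magnitude, the proposed loss is itself complex-valued. For comparison, `complex_abs_mse_loss` below is a hypothetical name (not part of complexPyTorch or PyTorch) that uses the squared magnitude instead, which gives a real scalar that `backward()` can be called on directly.

```python
import torch

def complex_mse_loss(output, target):
    # Proposed loss from this issue: complex square, complex-valued result.
    return (0.5 * (output - target) ** 2).mean(dtype=torch.complex64)

def complex_abs_mse_loss(output, target):
    # Hypothetical variant (an assumption, not from the issue):
    # squared magnitude of the difference, real-valued result.
    return (0.5 * (output - target).abs() ** 2).mean()

output = torch.tensor([1 + 2j, 3 - 1j], dtype=torch.complex64, requires_grad=True)
target = torch.tensor([1 + 1j, 2 - 1j], dtype=torch.complex64)

loss_c = complex_mse_loss(output, target)      # complex scalar
loss_r = complex_abs_mse_loss(output, target)  # real scalar
loss_r.backward()                              # gradients are stored in output.grad
```

With these inputs the differences are `1j` and `1`, so their complex squares (`-1` and `1`) cancel and `loss_c` comes out as zero, while `loss_r` is `0.5`. Which behaviour is wanted depends on the use case.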