functorch icon indicating copy to clipboard operation
functorch copied to clipboard

var and std decompositions are incorrect for complex numbers.

Open Chillee opened this issue 3 years ago • 0 comments

std decomposition is likely incorrect since it calls into var.

@register_decomposition(aten.var.correction)
def var_decomposition(x: Tensor, dims: Optional[List[int]], correction: int = 0, keepdim: bool = False):
    if dims is None:
        dims = []
    if len(dims) == 0:
        n = x.numel()
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    mean = torch.mean(x, dims, True)
    sub = x - mean
    sq = sub * sub
    sum = torch.sum(sq, dims, keepdim)

    if correction:
        n = n - correction

    return sum / n


@register_decomposition(aten.std.correction)
def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
    return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))

Chillee avatar Apr 29 '22 03:04 Chillee