functorch
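The `var` and `std` decompositions are incorrect for complex inputs. The `var` decomposition squares the deviation with `sub * sub`, which computes (x - mean)² rather than |x - mean|². For complex tensors this product is itself complex, so the result is not the real, non-negative variance. The `std` decomposition is likely also incorrect, since it calls into `var`.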
@register_decomposition(aten.var.correction)
def var_decomposition(
    x: Tensor,
    dims: Optional[List[int]],
    correction: int = 0,
    keepdim: bool = False,
):
    """Decompose ``aten.var.correction`` into mean / sum / elementwise ops.

    Computes ``sum(|x - mean(x)|**2 over dims) / (n - correction)``, where
    ``n`` is the number of elements reduced over.

    Args:
        x: Input tensor, real or complex.
        dims: Dimensions to reduce over. ``None`` or an empty list is
            treated as a full reduction (``n`` is ``x.numel()``).
        correction: Subtracted from ``n`` in the denominator. A falsy value
            (``0`` or ``None``) leaves ``n`` unchanged.
        keepdim: Passed to the final ``torch.sum``; keeps reduced dims.

    Returns:
        The variance. For complex ``x`` it is real-valued, because it is
        built from the squared magnitudes of the deviations.
    """
    if dims is None:
        dims = []

    # Count of elements participating in the reduction.
    if len(dims) == 0:
        n = x.numel()
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    # keepdim=True so the mean broadcasts back against x.
    mean = torch.mean(x, dims, True)
    sub = x - mean
    if x.is_complex():
        # Variance needs |z|^2 = re^2 + im^2, not z * z. The latter is
        # complex and can cancel. Avoids abs() to skip a sqrt round-trip.
        sq = sub.real * sub.real + sub.imag * sub.imag
    else:
        sq = sub * sub
    total = torch.sum(sq, dims, keepdim)

    # NOTE(review): correction=None is treated as 0 by this truthiness check;
    # torch.var documents a default correction of 1 — confirm intended.
    if correction:
        n = n - correction
    return total / n
@register_decomposition(aten.std.correction)
def std_decomposition(x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False):
    """Decompose ``aten.std.correction`` as the square root of the variance.

    Forwards ``dims``, ``correction`` and ``keepdim`` unchanged to
    ``torch.var`` and takes the elementwise square root of its result.
    Its behaviour for complex ``x`` is only as correct as that
    ``torch.var`` call.
    """
    variance = torch.var(x, dims, correction=correction, keepdim=keepdim)
    return variance.sqrt()