clean-fid
A better way to compute the FID
Hello, I think the following implementation of the Fréchet distance is faster than the current one and would allow dropping the `scipy` dependency.
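```python
import torch
from torch import Tensor

def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    a = (mu_x - mu_y).square().sum(dim=-1)  # squared distance between the means
    b = sigma_x.trace() + sigma_y.trace()  # Tr(sigma_x) + Tr(sigma_y)
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)  # Tr(sqrt(sigma_x @ sigma_y))
    return a + b - 2 * c
```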
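For reference, `a`, `b` and `c` correspond to the three terms of the usual closed form of the Fréchet distance between the Gaussians $\mathcal{N}(\mu_x, \Sigma_x)$ and $\mathcal{N}(\mu_y, \Sigma_y)$. The returned value is the squared distance, which is what FID reports:

$$
d^2 = \underbrace{\lVert \mu_x - \mu_y \rVert_2^2}_{a} + \underbrace{\operatorname{Tr}\Sigma_x + \operatorname{Tr}\Sigma_y}_{b} - 2\,\underbrace{\operatorname{Tr}\big(\sqrt{\Sigma_x \Sigma_y}\big)}_{c}
$$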
The implementation is based on two facts:
- The trace of $A$ equals the sum of its eigenvalues.
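- The eigenvalues of $\sqrt{A}$ are the square roots of the eigenvalues of $A$.

Applied to $A = \Sigma_x \Sigma_y$, they give $\operatorname{Tr}\sqrt{\Sigma_x \Sigma_y} = \sum_i \sqrt{\lambda_i}$, where $\lambda_i$ are the eigenvalues of $\Sigma_x \Sigma_y$, so the matrix square root never has to be formed explicitly. Because $\Sigma_x$ and $\Sigma_y$ are covariance matrices, $\Sigma_x \Sigma_y$ has the same eigenvalues as the positive semi-definite matrix $\Sigma_x^{1/2} \Sigma_y \Sigma_x^{1/2}$, so the $\lambda_i$ are real and non-negative. `torch.linalg.eigvals` still returns a complex tensor, which is why only the `.real` part of the square roots is summed.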
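A quick way to sanity-check these facts numerically is sketched below. It uses NumPy instead of PyTorch only to keep the check dependency-light, and the helper names `trace_sqrt_eigvals`, `trace_sqrt_symmetric` and `random_covariance` are hypothetical. The eigenvalue shortcut is compared against a reference that takes the eigenvalues of the symmetric matrix $\Sigma_x^{1/2} \Sigma_y \Sigma_x^{1/2}$, which shares its eigenvalues with $\Sigma_x \Sigma_y$.

```python
import numpy as np

def trace_sqrt_eigvals(sigma_x: np.ndarray, sigma_y: np.ndarray) -> float:
    # Shortcut: Tr(sqrt(Sx @ Sy)) = sum of square roots of the eigenvalues of Sx @ Sy.
    # Cast to complex so tiny negative round-off eigenvalues do not produce NaN.
    eig = np.linalg.eigvals(sigma_x @ sigma_y).astype(complex)
    return float(np.sqrt(eig).real.sum())

def trace_sqrt_symmetric(sigma_x: np.ndarray, sigma_y: np.ndarray) -> float:
    # Reference: build S = Sx^(1/2) from the eigendecomposition of the symmetric Sx,
    # then use eigvalsh on the symmetric PSD matrix S @ Sy @ S.
    w, v = np.linalg.eigh(sigma_x)
    s = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    m = s @ sigma_y @ s
    return float(np.sqrt(np.clip(np.linalg.eigvalsh(m), 0.0, None)).sum())

def random_covariance(rng: np.random.Generator, dim: int, n: int) -> np.ndarray:
    # Sample covariance of n random feature vectors (full rank when n > dim).
    feats = rng.standard_normal((n, dim))
    return np.cov(feats, rowvar=False)

rng = np.random.default_rng(0)
sx = random_covariance(rng, 16, 64)
sy = random_covariance(rng, 16, 64)
print(trace_sqrt_eigvals(sx, sy), trace_sqrt_symmetric(sx, sy))
```

The two printed values should agree up to floating-point round-off; the symmetric route is more expensive because it needs an extra eigendecomposition, which is what the shortcut avoids.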