clean-fid

A better way to compute the FID

Open · francois-rozet opened this issue 1 year ago · 0 comments

Hello, I think the following implementation of the Fréchet distance is faster than the current one and would allow dropping the scipy dependency, since it replaces the matrix square root with a single call to `torch.linalg.eigvals`.

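```python
import torch
from torch import Tensor

def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    # Squared Euclidean distance between the means
    a = (mu_x - mu_y).square().sum(dim=-1)
    # Tr(sigma_x) + Tr(sigma_y)
    b = sigma_x.trace() + sigma_y.trace()
    # Tr(sqrt(sigma_x @ sigma_y)) = sum of the square roots of the eigenvalues of sigma_x @ sigma_y
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)
    return a + b - 2 * c
```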
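If the mean and covariance are NumPy arrays rather than tensors, the same eigenvalue trick can be sketched with `np.linalg.eigvals` alone, so neither scipy nor torch is needed. This is a hypothetical helper (`frechet_distance_np` is not part of clean-fid) and assumes 1-D means and 2-D covariance matrices. One difference from the torch version: `torch.linalg.eigvals` always returns complex values, whereas `np.linalg.eigvals` returns a real array when all eigenvalues come out real, so the sketch casts to complex before the square root to keep tiny negative round-off from turning into `nan`.

```python
import numpy as np

def frechet_distance_np(mu_x, sigma_x, mu_y, sigma_y):
    """Squared Fréchet distance between N(mu_x, sigma_x) and N(mu_y, sigma_y), NumPy only (sketch)."""
    # Squared Euclidean distance between the means
    a = np.sum((mu_x - mu_y) ** 2)
    # Tr(sigma_x) + Tr(sigma_y)
    b = np.trace(sigma_x) + np.trace(sigma_y)
    # Eigenvalues of sigma_x @ sigma_y; cast to complex so that slightly negative
    # round-off values get an imaginary square root instead of nan
    eig = np.linalg.eigvals(sigma_x @ sigma_y).astype(complex)
    # Tr(sqrt(sigma_x @ sigma_y)) = sum of the square roots of those eigenvalues
    c = np.sqrt(eig).real.sum()
    return a + b - 2 * c
```

A quick sanity check is that identical statistics give a distance that is zero up to round-off, and that the result agrees with the symmetric form $\operatorname{Tr}\sqrt{\Sigma_x^{1/2}\Sigma_y\Sigma_x^{1/2}}$ computed with `np.linalg.eigh`.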

The implementation avoids computing a matrix square root explicitly by relying on two facts:

  1. The trace of $A$ equals the sum of its eigenvalues.
  2. The eigenvalues of $\sqrt{A}$ are the square roots of the eigenvalues of $A$.
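Written out, the squared Fréchet distance between $\mathcal{N}(\mu_x, \Sigma_x)$ and $\mathcal{N}(\mu_y, \Sigma_y)$ splits into the three terms `a`, `b` and `c` of the code, and the two facts turn the last term into an eigenvalue sum (first equality: fact 1, second equality: fact 2). The eigenvalues $\lambda_i(\Sigma_x\Sigma_y)$ are real and non-negative even though $\Sigma_x\Sigma_y$ is not symmetric, because $\Sigma_x\Sigma_y = (\Sigma_x\Sigma_y^{1/2})\,\Sigma_y^{1/2}$ has the same eigenvalues as the positive semi-definite matrix $\Sigma_y^{1/2}\Sigma_x\Sigma_y^{1/2}$; this is why taking `.real` after the complex square root only discards round-off.

$$
\begin{aligned}
d^2 &= \underbrace{\lVert \mu_x - \mu_y \rVert_2^2}_{a} + \underbrace{\operatorname{Tr}\Sigma_x + \operatorname{Tr}\Sigma_y}_{b} - 2\,\underbrace{\operatorname{Tr}\sqrt{\Sigma_x\Sigma_y}}_{c}, \\
c &= \sum_i \lambda_i\big(\sqrt{\Sigma_x\Sigma_y}\big) = \sum_i \sqrt{\lambda_i(\Sigma_x\Sigma_y)}.
\end{aligned}
$$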

francois-rozet, Apr 10 '23 19:04