google-research

Discrepancy with CMMD paper

Open • nicolas-dufour opened this issue 5 months ago • 2 comments

Hi, I think there is a bug in the CMMD implementation: it does not match the MMD estimator given in the paper.

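In the paper, the (unbiased) squared MMD distance between $X = \{x_1, \dots, x_n\}$ and $Y = \{y_1, \dots, y_m\}$ is:

$$\mathrm{dist}^2_{\mathrm{MMD}}(X, Y) = \frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j \neq i}^{n} k(x_i, x_j) + \frac{1}{m(m-1)} \sum_{i=1}^{m} \sum_{j \neq i}^{m} k(y_i, y_j) - \frac{2}{nm} \sum_{i=1}^{n} \sum_{j=1}^{m} k(x_i, y_j)$$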
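For reference, the paper's estimator (within-set sums over $j \neq i$ normalized by $n(n-1)$ and $m(m-1)$, cross term normalized by $nm$) can be transcribed term by term. This is a minimal sketch, assuming the Gaussian RBF kernel $k(a, b) = \exp(-\gamma \lVert a - b \rVert^2)$ used in the code; it uses NumPy instead of `jax.numpy`, the names `rbf_kernel` and `mmd2_unbiased_loops` are hypothetical, and the plain Python loops make it a readable reference only, not a replacement for the vectorized implementation.

```python
import numpy as np


def rbf_kernel(a, b, gamma):
    """Gaussian RBF kernel k(a, b) = exp(-gamma * ||a - b||^2).

    Assumption: same kernel form as the CMMD code; NumPy stands in for jax.numpy.
    """
    return float(np.exp(-gamma * np.sum((a - b) ** 2)))


def mmd2_unbiased_loops(x, y, gamma):
    """Unbiased squared MMD, transcribed term by term from the U-statistic.

    x has n rows and y has m rows; the within-set sums skip i == j.
    """
    n, m = len(x), len(y)
    sum_xx = sum(rbf_kernel(x[i], x[j], gamma)
                 for i in range(n) for j in range(n) if j != i)
    sum_yy = sum(rbf_kernel(y[i], y[j], gamma)
                 for i in range(m) for j in range(m) if j != i)
    sum_xy = sum(rbf_kernel(x[i], y[j], gamma)
                 for i in range(n) for j in range(m))
    return (sum_xx / (n * (n - 1))
            + sum_yy / (m * (m - 1))
            - 2 * sum_xy / (n * m))


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
y = rng.normal(loc=1.0, size=(5, 4))
print(mmd2_unbiased_loops(x, y, gamma=0.5))
```

Note that swapping $X$ and $Y$ leaves the value unchanged, since the two within-set terms trade places and the cross term is the same sum.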

However, in the code, $K_{X,X}$ and $K_{Y,Y}$ are the full $n \times n$ and $m \times m$ kernel matrices, including the diagonal entries $k(x_i, x_i)$ and $k(y_i, y_i)$:

https://github.com/google-research/google-research/blob/583d3178157a3dc1eaec04935387ec797004f09b/cmmd/distance.py#L59

These matrices are then reduced with `mean()`, which I don't think is correct: the normalization factor is $\frac{1}{n^2}$ (resp. $\frac{1}{m^2}$), whereas it should be $\frac{1}{n(n-1)}$ (resp. $\frac{1}{m(m-1)}$) to match the paper's equation.
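To see why `mean()` is not just a rescaling, write $\bar{k}_{XX} = \frac{1}{n(n-1)} \sum_{i \neq j} k(x_i, x_j)$ for the paper's within-set term and assume, as for the Gaussian RBF kernel in the code, that $k(x_i, x_i) = 1$. The mean of the full $n \times n$ matrix is then

$$\frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k(x_i, x_j) = \frac{n + n(n-1)\,\bar{k}_{XX}}{n^2} = \frac{1}{n} + \frac{n-1}{n}\,\bar{k}_{XX},$$

so it overshoots $\bar{k}_{XX}$ by $\frac{1 - \bar{k}_{XX}}{n}$, an error that depends on the data and only vanishes as $n \to \infty$ (likewise for $K_{Y,Y}$ with $m$).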

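Furthermore, since the diagonal is not masked (and $k(x, x) = 1$ for the Gaussian RBF kernel), changing the normalization to $\frac{1}{n(n-1)}$ and $\frac{1}{m(m-1)}$ alone would still leave a constant bias of $\frac{1}{n-1}+\frac{1}{m-1}$ in the MMD estimate.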
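The size of that bias can be checked numerically. The sketch below uses NumPy rather than `jax.numpy`, builds the RBF kernel matrices the same way as the code in the issue (expanded squared distances), and compares the estimate that divides the full sums by $n(n-1)$ and $m(m-1)$ against one that also drops the diagonal. The helper names and the value `gamma=0.5` are arbitrary choices for illustration. Because $k(x_i, x_i) = 1$, the gap should equal $\frac{1}{n-1} + \frac{1}{m-1}$ up to floating-point rounding, whatever the data.

```python
import numpy as np


def rbf_gram(a, b, gamma):
    """Matrix of exp(-gamma * ||a_i - b_j||^2), via expanded squared distances.

    Assumption: Gaussian RBF kernel, as in the CMMD code.
    """
    sq_dists = (
        -2 * a @ b.T
        + np.sum(a * a, axis=1)[:, None]
        + np.sum(b * b, axis=1)[None, :]
    )
    return np.exp(-gamma * sq_dists)


def mmd2_unmasked_and_masked(x, y, gamma):
    """Squared MMD with within-set sums normalized by n(n-1) and m(m-1).

    Returns (unmasked, masked): the first keeps the diagonal, the second drops it.
    """
    n, m = len(x), len(y)
    k_xx, k_yy = rbf_gram(x, x, gamma), rbf_gram(y, y, gamma)
    cross = 2 * rbf_gram(x, y, gamma).mean()  # no i == j pairs in the cross term
    unmasked = k_xx.sum() / (n * (n - 1)) + k_yy.sum() / (m * (m - 1)) - cross
    masked = (
        (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
        + (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
        - cross
    )
    return unmasked, masked


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = rng.normal(size=(5, 3))
unmasked, masked = mmd2_unmasked_and_masked(x, y, gamma=0.5)
# The gap is 1/(8-1) + 1/(5-1), about 0.392857, regardless of the data.
print(unmasked - masked)
```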

An easy fix would be:

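import jax.numpy as jnp

# Drop-in replacement for the within-set terms in mmd() (cmmd/distance.py),
# where x_sqnorms, y_sqnorms and gamma are already defined.
n = x.shape[0]
m = y.shape[0]
# k(x_i, x_i) = 1 for the RBF kernel, so the diagonal adds n / (n(n-1)) = 1/(n-1)
# to the normalized sum; subtracting it leaves the mean over i != j.
k_xx = jnp.sum(
    jnp.exp(
        -gamma
        * (
            -2 * jnp.matmul(x, x.T)
            + jnp.expand_dims(x_sqnorms, 1)
            + jnp.expand_dims(x_sqnorms, 0)
        )
    )
) / (n * (n - 1)) - 1 / (n - 1)
k_yy = jnp.sum(
    jnp.exp(
        -gamma
        * (
            -2 * jnp.matmul(y, y.T)
            + jnp.expand_dims(y_sqnorms, 1)
            + jnp.expand_dims(y_sqnorms, 0)
        )
    )
) / (m * (m - 1)) - 1 / (m - 1)
# k_xy needs no change: the cross term has no i == j pairs, so its mean()
# over the n x m matrix is already unbiased.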

nicolas-dufour, Jan 26 '24 13:01