Discrepancy with CMMD paper
Hi, I think there is a bug in the CMMD implementation with respect to the paper.
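In the paper, the MMD distance is the unbiased estimator (writing $n$ and $m$ for the number of samples in $X$ and $Y$):

$$
\widehat{\mathrm{dist}}^2_{\mathrm{MMD}}(X, Y) = \frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j \neq i} k(x_i, x_j) + \frac{1}{m(m-1)} \sum_{i=1}^{m} \sum_{j \neq i} k(y_i, y_j) - \frac{2}{nm} \sum_{i=1}^{n} \sum_{j=1}^{m} k(x_i, y_j)
$$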
However, looking at the code, $K_{X,X}$ and $K_{Y,Y}$ are computed as full $n \times n$ and $m \times m$ matrices, respectively:
https://github.com/google-research/google-research/blob/583d3178157a3dc1eaec04935387ec797004f09b/cmmd/distance.py#L59
Each matrix is then reduced with `mean()`, which does not look correct to me. The factor applied is $\frac{1}{n^2}$ (and $\frac{1}{m^2}$ for $K_{Y,Y}$), where it should be $\frac{1}{n(n-1)}$ (and $\frac{1}{m(m-1)}$) to match the equation in the paper.
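Furthermore, the diagonal is not masked, so the $i = j$ terms are included in the sums. Each diagonal entry is $k(x_i, x_i) = 1$, so even with the corrected normalization there is a constant bias of $\frac{1}{n-1}+\frac{1}{m-1}$.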
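
To spell out where that constant comes from, here is the split of the full double sum into its off-diagonal and diagonal parts. This assumes $k(x_i, x_i) = 1$, which holds for the Gaussian RBF kernel because $\exp(0) = 1$; the symbol $U_X$ is just shorthand introduced here for the off-diagonal average.

$$
\frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j=1}^{n} k(x_i, x_j) = \underbrace{\frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j \neq i} k(x_i, x_j)}_{U_X} + \frac{1}{n-1}
$$

Equivalently, the plain mean over the full matrix relates to the paper's term as

$$
\operatorname{mean}(K_{X,X}) = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k(x_i, x_j) = \frac{n-1}{n}\, U_X + \frac{1}{n}
$$

The same identities hold for $K_{Y,Y}$ with $m$ in place of $n$. The cross term $K_{X,Y}$ has no diagonal to exclude, so its mean already matches the $\frac{1}{nm}$ factor.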
An easy fix would be:
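
    # Inside mmd(x, y); x_sqnorms, y_sqnorms and gamma are defined as in the existing code.
    n = x.shape[0]
    m = y.shape[0]
    # Sum over all pairs, normalize by n(n-1), then subtract the diagonal's
    # contribution: n entries equal to k(x, x) = 1, i.e. n / (n(n-1)) = 1 / (n-1).
    k_xx = jnp.sum(
        jnp.exp(
            -gamma
            * (
                -2 * jnp.matmul(x, x.T)
                + jnp.expand_dims(x_sqnorms, 1)
                + jnp.expand_dims(x_sqnorms, 0)
            )
        )
    ) / (n * (n - 1)) - 1 / (n - 1)
    # Same correction for the y-y term, using the m samples in y.
    k_yy = jnp.sum(
        jnp.exp(
            -gamma
            * (
                -2 * jnp.matmul(y, y.T)
                + jnp.expand_dims(y_sqnorms, 1)
                + jnp.expand_dims(y_sqnorms, 0)
            )
        )
    ) / (m * (m - 1)) - 1 / (m - 1)
    # The cross term k_xy has no diagonal to exclude, so its jnp.mean stays as is.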
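
For anyone who wants to check the algebra independently of the repository code, here is a minimal standalone JAX sketch. It is not the repository's implementation: the function names `rbf_kernel`, `mmd2_biased` and `mmd2_unbiased` are hypothetical, `gamma` is passed in explicitly, and no output scaling is applied. It compares the plain-mean estimator with one that drops the diagonal by subtracting the trace.

```python
import jax.numpy as jnp


def rbf_kernel(a, b, gamma):
    """Gaussian RBF kernel matrix with entries exp(-gamma * ||a_i - b_j||^2)."""
    a_sq = jnp.sum(a * a, axis=1)
    b_sq = jnp.sum(b * b, axis=1)
    sq_dists = a_sq[:, None] + b_sq[None, :] - 2.0 * (a @ b.T)
    return jnp.exp(-gamma * sq_dists)


def mmd2_biased(x, y, gamma):
    """Plain mean over every kernel matrix, diagonal included (1/n^2, 1/m^2)."""
    return (
        rbf_kernel(x, x, gamma).mean()
        + rbf_kernel(y, y, gamma).mean()
        - 2.0 * rbf_kernel(x, y, gamma).mean()
    )


def mmd2_unbiased(x, y, gamma):
    """Within-set sums skip i == j and are divided by n(n-1) and m(m-1)."""
    n, m = x.shape[0], y.shape[0]
    k_xx = rbf_kernel(x, x, gamma)
    k_yy = rbf_kernel(y, y, gamma)
    k_xy = rbf_kernel(x, y, gamma)
    # Subtracting the trace removes the i == j terms from each within-set sum.
    xx = (k_xx.sum() - jnp.trace(k_xx)) / (n * (n - 1))
    yy = (k_yy.sum() - jnp.trace(k_yy)) / (m * (m - 1))
    # The cross term has no diagonal to exclude, so a plain mean is correct.
    return xx + yy - 2.0 * k_xy.mean()


if __name__ == "__main__":
    # Small deterministic inputs, chosen arbitrarily for illustration.
    x = jnp.arange(12.0).reshape(4, 3) / 10.0
    y = jnp.cos(jnp.arange(15.0)).reshape(5, 3)
    print("biased:  ", float(mmd2_biased(x, y, gamma=0.5)))
    print("unbiased:", float(mmd2_unbiased(x, y, gamma=0.5)))
```

Subtracting the trace avoids relying on the diagonal being exactly 1 in floating point. For the RBF kernel it should agree with the `sum / (n * (n - 1)) - 1 / (n - 1)` form up to rounding.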