google-research
Discrepancy with CMMD paper
Hi, I think there is a bug in the CMMD implementation: it does not match the paper.
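In the paper, the squared MMD distance is the unbiased estimator

$$\text{dist}^2_{\text{MMD}}(X, Y) = \frac{1}{n(n-1)}\sum_{i=1}^{n}\sum_{j \neq i}^{n} k(x_i, x_j) + \frac{1}{m(m-1)}\sum_{i=1}^{m}\sum_{j \neq i}^{m} k(y_i, y_j) - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(x_i, y_j)$$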
However, in the code, $K_{X,X}$ and $K_{Y,Y}$ are the full $n \times n$ and $m \times m$ kernel matrices, respectively:
https://github.com/google-research/google-research/blob/583d3178157a3dc1eaec04935387ec797004f09b/cmmd/distance.py#L59
These are then reduced with `mean()`, which I don't think is correct: the normalization factor becomes $\frac{1}{n^2}$, whereas it should be $\frac{1}{n(n-1)}$ to match the paper's equation.
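Furthermore, since the diagonal is not masked, the terms $k(x_i, x_i) = 1$ (Gaussian kernel) contribute an extra constant of $\frac{1}{n-1}+\frac{1}{m-1}$ once the sums are normalized by $n(n-1)$ and $m(m-1)$.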
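The diagonal contribution can be checked numerically. Below is a minimal plain-Python sketch; `gaussian_kernel`, `within_set_terms` and the `sigma=10.0` default are illustrative assumptions, not the repo's code. With a Gaussian kernel every diagonal entry is exactly 1, so after dividing by $n(n-1)$ the full-matrix sum exceeds the off-diagonal sum by exactly $\frac{n}{n(n-1)} = \frac{1}{n-1}$.

```python
import math
import random


def gaussian_kernel(a, b, sigma=10.0):
    """Gaussian RBF kernel exp(-||a - b||^2 / (2 sigma^2)); exactly 1 when a == b."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2 * sigma**2))


def within_set_terms(xs, sigma=10.0):
    """Return (full-matrix sum, off-diagonal sum), both divided by n(n-1)."""
    n = len(xs)
    norm = n * (n - 1)
    full = sum(gaussian_kernel(a, b, sigma) for a in xs for b in xs) / norm
    off_diag = sum(
        gaussian_kernel(xs[i], xs[j], sigma)
        for i in range(n)
        for j in range(n)
        if i != j
    ) / norm
    return full, off_diag


random.seed(0)
xs = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(20)]
full, off_diag = within_set_terms(xs)
# The gap is the diagonal's contribution n / (n(n-1)) = 1/(n-1), up to rounding.
print(full - off_diag, 1 / (len(xs) - 1))
```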
An easy fix would be:
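import jax.numpy as jnp


def mmd_unbiased(x, y, sigma=10.0, scale=1000.0):
  """Unbiased MMD^2 with a Gaussian RBF kernel (diagonal terms excluded)."""
  # sigma/scale defaults are assumed to match the constants in cmmd/distance.py.
  x = jnp.asarray(x)
  y = jnp.asarray(y)
  n, m = x.shape[0], y.shape[0]
  gamma = 1 / (2 * sigma**2)
  x_sqnorms = jnp.sum(x * x, axis=1)
  y_sqnorms = jnp.sum(y * y, axis=1)

  def gram(a, b, a_sqnorms, b_sqnorms):
    # k(a_i, b_j) = exp(-gamma * ||a_i - b_j||^2), with ||a - b||^2 = |a|^2 + |b|^2 - 2 a.b
    return jnp.exp(
        -gamma
        * (
            -2 * jnp.matmul(a, b.T)
            + jnp.expand_dims(a_sqnorms, 1)
            + jnp.expand_dims(b_sqnorms, 0)
        )
    )

  # Within-set terms: drop the diagonal k(x_i, x_i) and normalize by n(n-1).
  g_xx = gram(x, x, x_sqnorms, x_sqnorms)
  k_xx = (jnp.sum(g_xx) - jnp.trace(g_xx)) / (n * (n - 1))
  # Same for Y (this is k_yy: it uses y against y, normalized by m(m-1)).
  g_yy = gram(y, y, y_sqnorms, y_sqnorms)
  k_yy = (jnp.sum(g_yy) - jnp.trace(g_yy)) / (m * (m - 1))
  # Cross term: average over all n * m pairs, unchanged from the original mean().
  k_xy = jnp.mean(gram(x, y, x_sqnorms, y_sqnorms))
  return scale * (k_xx + k_yy - 2 * k_xy)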
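To sanity-check a vectorized version like the one above against the paper's equation, a literal double-loop evaluation is handy. This is a minimal plain-Python sketch: `mmd2_unbiased_reference` is a hypothetical helper name, and the `sigma=10.0` / `scale=1000.0` defaults are assumed to mirror the constants in `cmmd/distance.py`.

```python
import math


def gaussian_kernel(a, b, sigma):
    """Gaussian RBF kernel exp(-||a - b||^2 / (2 sigma^2))."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2 * sigma**2))


def mmd2_unbiased_reference(xs, ys, sigma=10.0, scale=1000.0):
    """Double-loop evaluation of the unbiased MMD^2 (hypothetical test helper).

    O((n + m)^2) kernel evaluations, so only suitable for small inputs.
    The sigma/scale defaults are assumed to mirror cmmd/distance.py.
    """
    n, m = len(xs), len(ys)
    # Within-set sums skip j == i, then normalize by n(n-1) and m(m-1).
    k_xx = sum(
        gaussian_kernel(xs[i], xs[j], sigma) for i in range(n) for j in range(n) if j != i
    ) / (n * (n - 1))
    k_yy = sum(
        gaussian_kernel(ys[i], ys[j], sigma) for i in range(m) for j in range(m) if j != i
    ) / (m * (m - 1))
    # Cross term averages over all n * m pairs.
    k_xy = sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys) / (n * m)
    return scale * (k_xx + k_yy - 2 * k_xy)


# All points coincide, so every kernel value is 1 and the three terms cancel.
print(mmd2_unbiased_reference([[0.0], [0.0]], [[0.0], [0.0]]))  # → 0.0
```

Comparing a vectorized implementation with this reference on a few hundred random points is a cheap way to catch mix-ups such as normalizing the wrong term.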
Hi @nicolas-dufour, thanks for reporting this. This code implements the minimum-variance version of the MMD estimator, as explained in the function docstring: https://github.com/google-research/google-research/blob/583d3178157a3dc1eaec04935387ec797004f09b/cmmd/distance.py#L35. We chose this version because most MMD implementations out there use it, and we did not want to introduce confusion.
Note, however, that the unbiased and minimum-variance versions are almost identical, as explained in the docstring (for the COCO 30K benchmark, $m = n = 30{,}000$). We will explain this in the paper.
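For reference, the gap between the two estimators can be written exactly. Assuming the Gaussian RBF kernel (so $k(x_i, x_i) = 1$), with $n$ samples in $X$ and $m$ in $Y$, the cross term is identical in both versions and only the within-set terms differ:

```latex
\begin{aligned}
U_{XX} &= \frac{1}{n(n-1)} \sum_{i=1}^{n} \sum_{j \neq i} k(x_i, x_j),
\qquad
\frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} k(x_i, x_j) = \frac{1}{n} + \frac{n-1}{n}\, U_{XX}, \\
\widehat{\mathrm{MMD}}^2_{\text{mean}} - \widehat{\mathrm{MMD}}^2_{\text{unbiased}}
&= \frac{1 - U_{XX}}{n} + \frac{1 - U_{YY}}{m}
\;\in\; \left[\,0,\ \frac{1}{n} + \frac{1}{m}\right).
\end{aligned}
```

Since $0 < k \le 1$, $\frac{1}{n} + \frac{1}{m}$ is an upper bound on the gap before any output scaling; how close the gap comes to that bound depends on how far the average off-diagonal kernel values $U_{XX}$ and $U_{YY}$ are from 1.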
Hi @sadeepj, thanks for the quick answer.
Correct me if I'm wrong, but once the $\times 1000$ scaling in the code is taken into account, the bias term is not negligible. For 30K samples, it's a bias of $1000 \times \frac{2}{30{,}000} = \frac{1}{15} \approx 0.067$.
This is amplified for a more data-efficient 5K-sample version, where the bias becomes $1000 \times \frac{2}{5{,}000} = 0.4$! That is not negligible given the gap between the 5K and 30K metrics in Figure 8, where the two data points are visually at most 0.03 apart.
Am I missing something here? Thanks for the help!
Same question.