
Negative loss in unbalanced Sinkhorn

Open MUCDK opened this issue 2 years ago • 3 comments

Unbalanced Sinkhorn results in negative reg_ot_cost.

import jax
import jax.numpy as jnp

from ott.core import LinearProblem
from ott.core.sinkhorn import Sinkhorn
from ott.geometry.pointcloud import PointCloud

n = 1000
dim = 30

rng = jax.random.PRNGKey(0)
rng, *rngs = jax.random.split(rng, 5)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (n, dim)) + 0.1
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (n,))
a_scaled = a / jnp.sum(a)
b_scaled = b / jnp.sum(b)

lp = LinearProblem(PointCloud(x, y, epsilon=1), a_scaled, b_scaled, tau_a=0.5, tau_b=0.6)
s = Sinkhorn()(lp)

print(s.reg_ot_cost)

The regularised OT cost is -8.937. The more unbalanced the problem, the more negative the OT cost becomes, which should not be the case in general.

Installation ott-jax: from source.

MUCDK avatar Aug 09 '22 10:08 MUCDK

Hi Dominik, thanks for this.

If you look at the primal cost, which is returned by reg_ot_cost, it includes KL terms (those should be > 0) but also a contribution corresponding to minus the entropy.
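To illustrate that minus-the-entropy point, here is a small NumPy sketch (the entropic penalty is written as $\varepsilon \sum_{ij} P_{ij}(\log P_{ij} - 1)$, one common convention, which is an assumption here): for a plan with many small entries, every term of that sum is negative, so the penalty can drag the primal objective below zero.

```python
import numpy as np

# Uniform toy coupling with entries << 1: each term P_ij * (log P_ij - 1)
# is negative, so the entropic penalty is negative overall.
n, eps = 100, 1.0
P = np.full((n, n), 1.0 / n**2)
ent_penalty = eps * np.sum(P * (np.log(P) - 1.0))
print(ent_penalty)  # negative
```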

In OTT, we compute the objective using the dual formulation https://github.com/ott-jax/ott/blob/d7521fdbe3e461a2192c8fc1e4080558dd453ee3/ott/core/sinkhorn.py#L142 and that also includes some negative contributions, notably in the div_a and div_b terms and in the total_sum term subtracted at the end.

So, having a negative ent_reg_cost is not a contradiction per se. However, this does not preclude a bug. Have you looked at the transport matrix that is returned? Do you see any unusual entries?

marcocuturi avatar Aug 17 '22 08:08 marcocuturi

Hi Marco,

Thanks for your response. I might be wrong, but reading equation 24 in https://arxiv.org/pdf/1910.12958.pdf, the term multiplied by epsilon in the entropic penalization should be non-positive: since f_i + g_j <= C_{ij}, every element of the right matrix in the inner product is non-positive, whereas each element of the left matrix is non-negative.
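This sign claim is easy to check numerically (a sketch; the exact shape of the matrices in eq. 24 is an assumption here): for dual potentials satisfying f_i + g_j <= C_{ij}, every entry of exp((f ⊕ g - C)/ε) - 1 is non-positive, so its inner product with the non-negative weights a_i b_j is non-positive.

```python
import numpy as np

rng = np.random.default_rng(0)
n, eps = 5, 1.0
C = rng.uniform(1.0, 2.0, (n, n))
# feasible potentials: f_i + g_j = 0.8 <= 1.0 <= C_ij by construction
f = np.full(n, 0.4)
g = np.full(n, 0.4)
a = rng.uniform(size=n)
b = rng.uniform(size=n)
K = np.exp((f[:, None] + g[None, :] - C) / eps) - 1.0  # entries <= 0
inner = np.sum(a[:, None] * b[None, :] * K)
print((K <= 0).all(), inner <= 0)  # True True
```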

This holds true for total_sum in the code, in the example above total_sum = -11.217885.

On the other hand, we add this term in the return statement, i.e. the return value (https://github.com/ott-jax/ott/blob/d7521fdbe3e461a2192c8fc1e4080558dd453ee3/ott/core/sinkhorn.py#L198) is div_a + div_b + ot_prob.epsilon * (jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum). Hence, I don't see where the minus in front of the epsilon in equation (24) is incorporated.

Hence, I would assume that the return statement starting in line 198 should read

div_a + div_b - ot_prob.epsilon * (
      jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum
  )

Apologies if I am wrong.

The transport matrix does not seem completely wrong, at least: for multiple scenarios, the Pearson correlation coefficient (PCC) with the WOT and POT implementations is larger than 0.98. If I am not mistaken, they use a stabilized version whereas OTT-JAX does not.

Moreover, if I understand correctly, OTT-JAX does not incorporate the scale of the marginals into the stopping criterion. This could be adapted easily and might help users.

MUCDK avatar Aug 17 '22 16:08 MUCDK

Thanks for checking this more closely. I must admit that part of the code was not extensively reviewed, so I am very grateful for you taking a closer look!!

Here total_sum corresponds to $\langle\alpha \otimes \beta, e^{\frac{f\oplus g-C}{\varepsilon}}\rangle$, since the elementwise product of these two matrices, i.e. $[\alpha_i \beta_j e^{\frac{f_i+ g_j-C_{ij}}{\varepsilon}}]_{ij}$, corresponds to the transport.

Since that term has a minus in front of it in equation 24, I think this should still be a $-\varepsilon \times$ total_sum.

Similarly, jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) corresponds to $\langle\alpha \otimes \beta, 1\rangle$, and that should be added (with a +) because of the two minus signs. So after a quick check I think that's still a + in front of ot_prob.epsilon; is anything mistaken on my end? Thanks again for checking!
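This algebra can be sanity-checked numerically: with total_sum $= \langle\alpha \otimes \beta, M\rangle$ for any matrix $M$ standing in for $e^{\frac{f\oplus g-C}{\varepsilon}}$ (a random stand-in below, an assumption of this sketch), the $-\varepsilon$ term of equation 24 equals $+\varepsilon(\sum_i a_i \sum_j b_j - $ total_sum$)$, which is exactly what the code adds.

```python
import numpy as np

rng = np.random.default_rng(1)
n, eps = 4, 0.5
a = rng.uniform(size=n)
b = rng.uniform(size=n)
M = rng.uniform(size=(n, n))  # stand-in for exp((f ⊕ g - C) / eps)
total_sum = np.sum(a[:, None] * b[None, :] * M)
# -eps * <a ⊗ b, M - 1>  ==  eps * (sum(a) * sum(b) - total_sum)
lhs = -eps * np.sum(a[:, None] * b[None, :] * (M - 1.0))
rhs = eps * (a.sum() * b.sum() - total_sum)
print(np.isclose(lhs, rhs))  # True
```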

As for the marginals, I think you are right: some scaling factor is definitely needed in the stopping criterion. That's a very good idea. For instance, the default tolerance could be jnp.sum(a) * tolerance. I hope this does not have complicated side effects when optimizing over $a$ in an unbalanced setting, but I think that's reasonable.
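A minimal sketch of such a mass-scaled criterion (the function name and exact form are hypothetical, not OTT's API): the L1 violation of the row marginal is compared against a tolerance proportional to the total mass of `a`, rather than an absolute one.

```python
import numpy as np

def marginals_converged(P, a, tol=1e-3):
    # Hypothetical scaled stopping check: L1 deviation of the row
    # marginal of P from a, relative to the total mass of a.
    err = np.abs(P.sum(axis=1) - a).sum()
    return err <= tol * a.sum()

# A plan whose row marginals match `a` exactly passes the check,
# even when `a` carries mass far from 1.
a = np.array([0.2, 0.3, 0.5]) * 10.0          # unnormalized, mass 10
P = a[:, None] * np.array([0.5, 0.5])[None, :]
print(marginals_converged(P, a))  # True
```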

marcocuturi avatar Aug 18 '22 03:08 marcocuturi

Hi Marco,

Thanks, yeah, I spotted my mistake and your reasoning makes perfect sense!

MUCDK avatar Aug 18 '22 12:08 MUCDK