Negative loss in unbalanced Sinkhorn
Unbalanced Sinkhorn results in a negative reg_ot_cost.
import jax
import jax.numpy as jnp

from ott.core import LinearProblem
from ott.core.sinkhorn import Sinkhorn
from ott.geometry.pointcloud import PointCloud

n = 1000
dim = 30

# Two random point clouds, slightly shifted against each other.
rng = jax.random.PRNGKey(0)
rng, *rngs = jax.random.split(rng, 5)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (n, dim)) + 0.1

# Random marginal weights, normalized to sum to one.
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (n,))
a_scaled = a / jnp.sum(a)
b_scaled = b / jnp.sum(b)

# Unbalanced problem: tau_a, tau_b < 1 relax the marginal constraints.
lp = LinearProblem(
    PointCloud(x, y, epsilon=1.0),
    a_scaled,
    b_scaled,
    tau_a=0.5,
    tau_b=0.6,
)
s = Sinkhorn()(lp)
print(s.reg_ot_cost)
The regularised OT cost is -8.937. The more unbalanced the problem, the more negative the OT cost becomes, which should not be the case in general.
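For reference, a minimal sketch (reusing the setup above) to reproduce that trend by sweeping the unbalancedness parameters; the printed values will depend on the installed version:

# Sweep tau_a = tau_b from balanced (1.0) to strongly unbalanced (0.5)
# and print the resulting regularized OT cost.
for tau in (1.0, 0.9, 0.7, 0.5):
    lp_tau = LinearProblem(
        PointCloud(x, y, epsilon=1.0), a_scaled, b_scaled, tau_a=tau, tau_b=tau
    )
    print(tau, Sinkhorn()(lp_tau).reg_ot_cost)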
Installation: ott-jax from source.
Hi Dominik, thanks for this.
If you look at the primal cost, which is returned by reg_ot_cost, it includes KL terms (those should be > 0) but also a contribution corresponding to minus the entropy.
In OTT, we compute the objective using the dual formulation,
https://github.com/ott-jax/ott/blob/d7521fdbe3e461a2192c8fc1e4080558dd453ee3/ott/core/sinkhorn.py#L142
and that also includes some negative contributions, notably in the div_a and div_b terms, and in the total_sum term subtracted at the end.
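Schematically (my shorthand, not a verbatim transcription of the code), the quantity computed there looks like

$$\mathrm{div}_a + \mathrm{div}_b + \varepsilon\big(\langle \alpha \otimes \beta, 1\rangle - \langle \alpha \otimes \beta, e^{(f \oplus g - C)/\varepsilon}\rangle\big),$$

where div_a and div_b are the dual terms built from the potentials $f$ and $g$, and both can be negative.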
So, having a negative ent_reg_cost is not a contradiction per se. However, this does not preclude a bug. Have you looked at the transport matrix that is returned? Do you see any unusual entries?
Hi Marco,
Thanks for your response. I might be wrong, but reading equation (24) in https://arxiv.org/pdf/1910.12958.pdf, the term multiplied by epsilon in the entropic penalization should be non-positive: since $f_i + g_j \leq C_{ij}$, every element of the right matrix in the inner product is non-positive, whereas each term in the left matrix is non-negative. This holds true for total_sum in the code; in the example above, total_sum = -11.217885.
On the other hand, we add this term in the return statement, i.e. the return value (https://github.com/ott-jax/ott/blob/d7521fdbe3e461a2192c8fc1e4080558dd453ee3/ott/core/sinkhorn.py#L198) is div_a + div_b + ot_prob.epsilon * (jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum). Hence, I don't see where the minus in front of the epsilon in equation (24) is incorporated.
Hence, I would assume that the return statement starting in line 198 should read
div_a + div_b - ot_prob.epsilon * (
    jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum
)
Apologies if I am wrong.
The transport matrix does not seem to be completely off: in multiple scenarios, the PCC with the implementations from WOT and POT is larger than 0.98. If I am not mistaken, they use a stabilized version whereas OTT-JAX does not.
Moreover, if I understand correctly, OTT-JAX does not incorporate the scale of the marginals into the stopping criterion. This could easily be adapted and might help users.
Thanks for checking this more closely. I must admit that part of the code was not extensively reviewed, so I am very grateful for you taking a closer look!!
Here, total_sum corresponds to $\langle\alpha \otimes \beta, e^{\frac{f\oplus g-C}{\varepsilon}}\rangle$, since the elementwise product of these two matrices, i.e. $[\alpha_i \beta_j e^{\frac{f_i+ g_j-C_{ij}}{\varepsilon}}]_{ij}$, corresponds to the transport plan. Since that term has a minus in front of it in equation (24), I think this should still be a $-\varepsilon \times$ total_sum.
Similarly, jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) corresponds to $\langle\alpha \otimes \beta, 1\rangle$, and that should be added (with a +) because of the two minus signs. So, with a quick check, I think that's still a + in front of ot_prob.epsilon. Anything mistaken on my end? Thanks again for checking!
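Spelling out the sign bookkeeping as I read it: the entropic term in equation (24) is

$$-\varepsilon\,\Big\langle \alpha \otimes \beta,\; e^{\frac{f\oplus g-C}{\varepsilon}} - 1\Big\rangle = \varepsilon\,\big(\langle\alpha \otimes \beta, 1\rangle - \langle\alpha \otimes \beta, e^{\frac{f\oplus g-C}{\varepsilon}}\rangle\big),$$

and the right-hand side is exactly epsilon times jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum, with a + in front.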
As for the marginals, I think you are right: some scaling factor is definitely needed in the stopping criterion, and I think that's a very good idea. For instance, we could use a tolerance equal by default to jnp.sum(a) * tolerance. I hope this does not have complicated side effects when optimizing over $a$ in an unbalanced setting, but I think that's reasonable.
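A hypothetical sketch of what that could look like (placeholder names, not the actual ott variables):

def has_converged(err, a, tolerance):
    # Scale the user-specified tolerance by the total mass of the
    # (possibly unnormalized) marginal, so the criterion is invariant
    # to rescaling a.
    return err <= tolerance * jnp.sum(a)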
Hi Marco,
Thanks, yeah, I spotted my mistake and your reasoning makes perfect sense!