
sinkhorn divergence appears to be calculated incorrectly

Open jacksonloper opened this issue 2 years ago • 2 comments

Describe the bug

As far as I can tell, ot.bregman.empirical_sinkhorn_divergence returns only the transport costs and ignores the regularization costs. The documentation says it returns both. This was also mentioned in #255, but the concern in that issue was whether it matches some papers (and there have been a lot of papers!).

My concern here is that the documentation of ot.bregman.empirical_sinkhorn_divergence gives a precise formula for what it does, but then the function itself appears to compute something quite different.

To Reproduce

Steps to reproduce the behavior:

  1. Run code below.
  2. Feel confused.

Code sample

import ot
import numpy as np
import scipy as sp
import scipy.spatial  # import the submodule explicitly; `import scipy` alone may not expose sp.spatial

# setup problem
ptsA=np.r_[0:5:10j][:,None]
ptsB=np.r_[1:3:20j][:,None]
eps=1

# get distance matrices
C1=sp.spatial.distance.cdist(ptsA,ptsB)**2
C2=sp.spatial.distance.cdist(ptsA,ptsA)**2
C3=sp.spatial.distance.cdist(ptsB,ptsB)**2

# get transport plans
pot_plan1=ot.bregman.empirical_sinkhorn(ptsA,ptsB,eps)
pot_plan2=ot.bregman.empirical_sinkhorn(ptsA,ptsA,eps)
pot_plan3=ot.bregman.empirical_sinkhorn(ptsB,ptsB,eps)

# compute transport costs for sinkhorn divergence
transport_costs=np.sum(C1*pot_plan1)-.5*np.sum(C2*pot_plan2)-.5*np.sum(C3*pot_plan3)

# compute entropic costs for sinkhorn divergence
# (the entropy term is scaled by the regularization eps; here eps=1, so it is a no-op)
entropic_costs = eps*(np.sum(pot_plan1*np.log(pot_plan1))-.5*np.sum(pot_plan2*np.log(pot_plan2))-.5*np.sum(pot_plan3*np.log(pot_plan3)))

# print results
print('transport costs'.rjust(30),transport_costs)
print('entropic costs'.rjust(30),entropic_costs)
print('sinkhorn divergence'.rjust(30),transport_costs+entropic_costs)

# compare with results from ot
print('result from ot'.rjust(30),ot.bregman.empirical_sinkhorn_divergence(ptsA,ptsB,eps))

Expected behavior

I expect the result from ot.bregman.empirical_sinkhorn_divergence to be the same as the sinkhorn divergence as I have calculated it. Instead, it seems to be identical to only the transport-cost portion of the divergence.
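For reference, here is the quantity I expect, wrapped in a self-contained sketch using plain NumPy/SciPy Sinkhorn iterations with uniform weights (the helper names `sinkhorn_plan` and `sinkhorn_divergence_with_entropy` are mine, not POT API):

```python
import numpy as np
from scipy.spatial.distance import cdist

def sinkhorn_plan(X, Y, eps, n_iter=500):
    """Entropic OT plan between uniform-weight point clouds via plain Sinkhorn iterations."""
    C = cdist(X, Y) ** 2
    a = np.full(len(X), 1.0 / len(X))
    b = np.full(len(Y), 1.0 / len(Y))
    K = np.exp(-C / eps)          # Gibbs kernel; strictly positive, so log(P) below is safe
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :], C

def sinkhorn_divergence_with_entropy(X, Y, eps):
    """S(X, Y) = L(X, Y) - 0.5 L(X, X) - 0.5 L(Y, Y),
    where L(A, B) = <P, C> + eps * sum(P log P) includes the entropy term."""
    def loss(A, B):
        P, C = sinkhorn_plan(A, B, eps)
        return np.sum(P * C) + eps * np.sum(P * np.log(P))
    return loss(X, Y) - 0.5 * loss(X, X) - 0.5 * loss(Y, Y)
```

By construction this satisfies S(X, X) = 0, which is the main property the debiasing terms are supposed to buy you.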

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Ubuntu
  • Python version: 3.9.7
  • How was POT installed (source, pip, conda): pip
Linux-5.10.47-linuxkit-x86_64-with-glibc2.31
Python 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:46) 
[GCC 9.4.0]
NumPy 1.20.3
SciPy 1.7.2
POT 0.8.2

jacksonloper avatar Jun 07 '22 13:06 jacksonloper

As far as I can tell this was acknowledged by @rflamary here, and a PR adding a small documentation note about this behavior was proposed in #423 (I don't know if it is live yet). I guess a fix/option (to include the entropic contribution) should come at some point.

tlacombe avatar Apr 14 '23 07:04 tlacombe

Yes, actually I now provide both losses in the new ot.solve API, and I plan to also provide the proper divergence for empirical distributions when implementing the new API (with lazy distance evaluation).

Feel free to update the sinkhorn divergence function to include the entropy term and add an option to not compute it (although I believe that without the entropy term we recover the "sharp" Sinkhorn, which is better but less/not studied).

rflamary avatar Apr 14 '23 09:04 rflamary