
Wasserstein Circle distance doesn't seem correct?

ckp95 opened this issue 7 months ago • 4 comments

Describe the bug

Forgive me if I've misunderstood what the wasserstein_circle function is supposed to do. But I would have thought that if we have

d1 = wasserstein_circle(arr1, arr2)

If we then add the same amount to both arrays (i.e. rotate both samples by the same angle around the circle), we should get the same answer.

d2 = wasserstein_circle(arr1 + delta, arr2 + delta)
assert d1 == d2

But the first example I tried fails.

To Reproduce

import numpy as np
import ot

sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])

d1 = ot.wasserstein_circle(sample1, sample2)

delta = 0.02

d2 = ot.wasserstein_circle(sample1 + delta, sample2 + delta)

assert d1 == d2 # fails

Expected behavior

wasserstein_circle should be rotationally symmetric, i.e. it should obey the property

d1 = wasserstein_circle(arr1, arr2)
d2 = wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
assert d1 == d2

This should hold for every real delta (up to floating-point error), because rotating both samples together amounts to just turning your head to the side.

Or am I just misunderstanding how the input is supposed to be represented here?
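One way to tell a real bug from a misunderstood input convention is to compare against an independent oracle. The sketch below is hypothetical (none of these helpers are part of POT) and assumes what the report assumes: two equal-size samples with uniform weights, coordinates in [0, 1) on a circle of circumference 1 (matching the `% 1` wrapping above), and p = 1. With equal sizes and uniform weights an optimal transport plan can be taken to be a permutation, so enumerating every matching gives the exact distance; this is only practical for a handful of points.

```python
import itertools

import numpy as np


def circle_dist(x, y):
    """Geodesic distance on a circle of circumference 1."""
    d = np.abs(np.asarray(x, dtype=float) - np.asarray(y, dtype=float)) % 1.0
    return np.minimum(d, 1.0 - d)


def brute_force_w1_circle(u, v):
    """Exact circular W1 between two equal-size, uniformly weighted samples.

    Hypothetical helper: enumerates all n! matchings, so keep n tiny.
    """
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    if u.ndim != 1 or u.shape != v.shape:
        raise ValueError("expects two 1-D samples of the same size")
    best = np.inf
    for perm in itertools.permutations(range(len(v))):
        # average cost of sending u[i] to v[perm[i]]
        best = min(best, float(circle_dist(u, v[list(perm)]).mean()))
    return best


def is_rotation_invariant(dist, u, v, delta, atol=1e-9):
    """True if dist(u, v) matches dist on both samples rotated by delta."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    d1 = dist(u, v)
    d2 = dist((u + delta) % 1, (v + delta) % 1)
    return bool(np.isclose(d1, d2, atol=atol))


sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])
print(brute_force_w1_circle(sample1, sample2))
print(is_rotation_invariant(brute_force_w1_circle, sample1, sample2, 0.02))
```

The same check could be pointed at POT, e.g. `is_rotation_invariant(ot.wasserstein_circle, sample1, sample2, 0.5)`; `np.isclose` also accepts the one-element result that the `[d1] = ...` unpacking later in the thread suggests it returns.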

Environment

  • OS: NixOS 24.11
  • Python version: 3.12.7
  • How POT was installed: Nix

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-6.6.67-x86_64-with-glibc2.40
Python 3.12.7 (main, Oct  1 2024, 02:05:46) [GCC 13.3.0]
NumPy 1.26.4
SciPy 1.14.1
POT 0.9.4


ckp95, May 20 '25 18:05

You are right, wasserstein_circle should be rotationally symmetric. The problem seems to come from the closed form we approximate in the case p=1, which is the default for wasserstein_circle and does not always work well. If you use binary_search_circle instead, the result should be rotationally symmetric.

I have just proposed in PR #736 to change the default behaviour of wasserstein_circle so that it always uses binary_search_circle, which seems to work better in general.
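For context, the exact p = 1 distance on a circle of circumference 1 has a well-known closed form: with empirical CDFs F_u and F_v on [0, 1), W1(u, v) = min over alpha of the integral over [0, 1] of |F_u(t) - F_v(t) - alpha| dt, and the minimum is attained when alpha is a median of F_u - F_v weighted by interval length. Evaluated exactly, this is rotation invariant, which is consistent with the diagnosis above that the trouble lies in the approximation. The sketch below is a minimal NumPy version under stated assumptions (uniform weights; `circle_w1_closed_form` is a hypothetical name); it is not POT's implementation.

```python
import numpy as np


def circle_w1_closed_form(u, v):
    """Circular W1 (circumference 1, uniform weights) via the p = 1 closed form.

    Hypothetical helper, not POT's code:
    W1 = min_alpha integral_0^1 |F_u(t) - F_v(t) - alpha| dt.
    """
    u = np.sort(np.asarray(u, dtype=float) % 1.0)
    v = np.sort(np.asarray(v, dtype=float) % 1.0)
    # F_u - F_v is piecewise constant between consecutive sample points
    t = np.concatenate(([0.0], np.sort(np.concatenate((u, v))), [1.0]))
    left, lengths = t[:-1], np.diff(t)
    # empirical CDFs evaluated on each interval [t_k, t_{k+1})
    h = (np.searchsorted(u, left, side="right") / len(u)
         - np.searchsorted(v, left, side="right") / len(v))
    # optimal alpha: median of h weighted by interval length
    order = np.argsort(h, kind="stable")
    cum = np.cumsum(lengths[order])
    alpha = h[order][np.searchsorted(cum, 0.5 * cum[-1])]
    return float(np.sum(lengths * np.abs(h - alpha)))


sample1 = [0.1, 0.11, 0.4, 0.6]
sample2 = [0.21, 0.15, 0.7, 0.95]
d1 = circle_w1_closed_form(sample1, sample2)
d2 = circle_w1_closed_form(np.add(sample1, 0.02) % 1, np.add(sample2, 0.02) % 1)
print(d1, d2)
```

On the samples from the report this gives approximately 0.12 both before and after the delta = 0.02 shift. Unlike a matching-based check, it also handles samples of different sizes.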

clbonet, May 20 '25 18:05

Okay, good to know I'm not going crazy 😅

Have you considered using hypothesis to property-test this library? It works particularly well for mathematical code with lots of clean invariants and symmetries. This property test immediately finds a minimal counterexample:

import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import ot


@given(
    arr1=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    arr2=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    delta=st.floats(min_value=-1, max_value=1)
)
def test_wasserstein_circle_is_rotationally_symmetric(arr1, arr2, delta):
    arr1 = np.array(arr1)
    arr2 = np.array(arr2)

    [d1] = ot.wasserstein_circle(arr1, arr2)
    [d2] = ot.wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)

    assert np.isclose(d1, d2, atol=1e-3)

E       Falsifying example: test_wasserstein_circle_is_rotationally_symmetric(
E           arr1=[0.0],
E           arr2=[0.125],
E           delta=0.5,
E       )

The binary_search_circle function passes this test 🙂

ckp95, May 20 '25 19:05

> The binary_search_circle function passes this test 🙂

Actually, not quite. This gives a warning:

E               RuntimeWarning: invalid value encountered in divide
E               Falsifying example: test_binary_search_circle_is_rotationally_symmetric(
E                   arr1=[0.5],
E                   arr2=[0.25],
E                   delta=0.5,
E               )

Standalone example:

import numpy as np
import ot
import warnings
warnings.filterwarnings("error")

arr1 = np.array([0.5])
arr2 = np.array([0.25])
delta = 0.5

[d1] = ot.binary_search_circle(arr1, arr2)
[d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)

assert np.isclose(d1, d2, atol=1e-3)

---------------------------------------------------------------------------
RuntimeWarning                            Traceback (most recent call last)
Cell In[16], line 9
      6 delta = 0.5
      8 [d1] = ot.binary_search_circle(arr1, arr2)
----> 9 [d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
     11 assert np.isclose(d1, d2, atol=1e-3)

File /nix/store/nlrk6vd18px7gg88yjaxgn0l8dalpbcg-python3-3.12.7-env/lib/python3.12/site-packages/ot/lp/solver_1d.py:737, in binary_search_circle(u_values, v_values, u_weights, v_weights, p, Lm, Lp, tm, tp, eps, require_sort, log)
    734     Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
    736     mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
--> 737     tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
    738     done[nx.prod(mask, axis=-1) > 0] = 1
    739 elif nx.any(1 - done):

RuntimeWarning: invalid value encountered in divide

Not sure if this is a real bug, but I thought it was worth recording.
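Both falsifying examples in this thread consist of single points, for which the exact answer is just the geodesic distance between them (0.125 and 0.25). A hypothetical helper like the one below pins them as fixed regression cases that run without hypothesis. `check_regressions` accepts any callable taking two 1-D arrays, so passing `ot.wasserstein_circle` or `ot.binary_search_circle` would exercise POT itself; running it under `warnings.simplefilter("error")`, as in the standalone example, would also make the RuntimeWarning surface as an exception.

```python
import numpy as np

# (arr1, arr2, delta) falsifying examples reported in this thread
REGRESSION_CASES = [
    ([0.0], [0.125], 0.5),  # wasserstein_circle result changed under rotation
    ([0.5], [0.25], 0.5),   # binary_search_circle raised a RuntimeWarning
]


def geodesic(x, y):
    """Exact W1 between two single points on a circle of circumference 1."""
    d = abs(x - y) % 1.0
    return min(d, 1.0 - d)


def check_regressions(dist, atol=1e-6):
    """Return the (u, v, got, expected) tuples where `dist` is wrong."""
    failures = []
    for a, b, delta in REGRESSION_CASES:
        a, b = np.array(a), np.array(b)
        expected = geodesic(a[0], b[0])
        # check the original pair and the pair rotated by delta
        for u, v in ((a, b), ((a + delta) % 1, (b + delta) % 1)):
            got = float(np.ravel(dist(u, v))[0])  # scalar or 1-element array
            if not np.isclose(got, expected, atol=atol):
                failures.append((u.tolist(), v.tolist(), got, expected))
    return failures


def exact(u, v):
    return geodesic(u[0], v[0])


def ignores_wraparound(u, v):
    # deliberately wrong: distance on the line, not on the circle
    return abs(u[0] - v[0])


print(check_regressions(exact))  # []
print(check_regressions(ignores_wraparound))  # one failure, for [0.0] vs [0.75]
```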

ckp95, May 20 '25 19:05

Sorry, my bad ^^. I shouldn't have used this implementation for the default behaviour...

I didn't know about the hypothesis library. It seems very useful! I am reassured that binary_search_circle seems to pass the test...

Concerning binary_search_circle, thank you for noticing this little bug. I think the warning arises because dCptm - dCmtp == 0 for some entries. This is not a problem in itself, since those entries are not updated thanks to the condition mask_end > 0, but it would be nice to avoid raising a warning in this case. I will see if I can fix this.
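NumPy reports 0/0 as "invalid value encountered in divide" (a nonzero value divided by zero gives "divide by zero encountered in divide" instead), so the warning most likely comes from masked-out entries where both numerator and denominator are zero. One common way to avoid evaluating those entries at all is `np.divide` with `where=` and an explicit `out=` buffer; wrapping the line in `np.errstate(divide="ignore", invalid="ignore")` is the other usual option. The sketch below uses a hypothetical helper name and is not necessarily how POT handles it.

```python
import warnings

import numpy as np


def masked_ratio(num, den, mask, fill=0.0):
    """num / den where mask is True, `fill` elsewhere.

    Hypothetical helper: with where=..., NumPy skips the masked-out entries,
    so no 0/0 is evaluated and no RuntimeWarning is emitted.
    """
    num, den = np.asarray(num, dtype=float), np.asarray(den, dtype=float)
    out = np.full(np.broadcast(num, den).shape, fill)
    np.divide(num, den, out=out, where=mask)
    return out


num = np.array([0.0, 1.0, 2.0])
den = np.array([0.0, 0.5, 0.0])
# same spirit as mask_end in the traceback: only divide where |den| is not tiny
mask = np.abs(den) > 1e-3

with warnings.catch_warnings():
    warnings.simplefilter("error")  # turn any RuntimeWarning into an exception
    ratio = masked_ratio(num, den, mask)

print(ratio)  # [0. 2. 0.]
```

Applied to the line in the traceback, `numerator` standing for `Ctp - Ctm + tm * dCptm - tp * dCmtp`, the update would read `tc[mask_end > 0] = masked_ratio(numerator, dCptm - dCmtp, mask_end > 0)[mask_end > 0]`. Since POT routes array operations through its backend object `nx`, a real fix would need a backend-agnostic equivalent rather than `np.divide` directly.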

clbonet, May 20 '25 19:05

Closing since this was fixed in #736

rflamary, Oct 21 '25 11:10