Wasserstein Circle distance doesn't seem correct?
Describe the bug
Forgive me if I've misunderstood what the wasserstein_circle function is supposed to do. But I would have thought that if we have
d1 = wasserstein_circle(arr1, arr2)
And then we add the same amount to both arrays (i.e. rotate both samples by the same angle around the circle), then we should get the same answer.
d2 = wasserstein_circle(arr1 + delta, arr2 + delta)
assert d1 == d2
But the first example I tried fails.
To Reproduce
import numpy as np
import ot
sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])
d1 = ot.wasserstein_circle(sample1, sample2)
delta = 0.02
d2 = ot.wasserstein_circle(sample1 + delta, sample2 + delta)
assert d1 == d2 # fails
Expected behavior
wasserstein_circle should be rotationally symmetric, i.e. it should obey the property
d1 = wasserstein_circle(arr1, arr2)
d2 = wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
assert d1 == d2
For all real delta (up to floating point inaccuracies), because this amounts to just turning your head to the side.
Or am I just misunderstanding how the input is supposed to be represented here?
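One way to pin down the expected property independently of POT is a brute-force reference. The sketch below is only an illustration under stated assumptions, not POT's algorithm: it takes the circle to be [0, 1) (as the `% 1` above suggests), assumes both samples have the same size with uniform weights and p=1, and uses the hypothetical helper names `circle_dist` and `brute_force_w1_circle`. It tries every one-to-one matching, so it is only practical for a handful of points.

```python
from itertools import permutations

import numpy as np


def circle_dist(x, y):
    """Geodesic distance on a circle of circumference 1."""
    d = np.abs(x - y) % 1.0
    return np.minimum(d, 1.0 - d)


def brute_force_w1_circle(u, v):
    """Reference W1 for equal-size, uniformly weighted samples (assumption).

    Enumerates all one-to-one matchings, so the cost grows as n!.
    """
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    if u.shape != v.shape:
        raise ValueError("this sketch only handles samples of equal size")
    n = len(u)
    cost = circle_dist(u[:, None], v[None, :])  # n x n pairwise distances
    rows = np.arange(n)
    best = min(cost[rows, list(perm)].sum() for perm in permutations(range(n)))
    return best / n


sample1 = np.array([0.1, 0.11, 0.4, 0.6])
sample2 = np.array([0.21, 0.15, 0.7, 0.95])
delta = 0.02

d1 = brute_force_w1_circle(sample1, sample2)
d2 = brute_force_w1_circle((sample1 + delta) % 1, (sample2 + delta) % 1)
print(np.isclose(d1, d2))
```

Because the pairwise circle distances do not change when both samples are shifted by the same delta, the two values agree up to floating-point error, which is the behaviour expected of wasserstein_circle here.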
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): NixOS 24.11
- Python version: 3.12.7
- How was POT installed (source, pip, conda): Nix
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-6.6.67-x86_64-with-glibc2.40
Python 3.12.7 (main, Oct 1 2024, 02:05:46) [GCC 13.3.0]
NumPy 1.26.4
SciPy 1.14.1
POT 0.9.4
Additional context
You are right, wasserstein_circle should be rotationally symmetric. The problem seems to come from the closed form we try to approximate in the case p=1, which is the default for wasserstein_circle and does not always work well. If you try binary_search_circle instead, it should be rotationally symmetric.
I just proposed in PR #736 to change the default behaviour of wasserstein_circle so that it always uses binary_search_circle, which seems to work better in general.
Okay, good to know I'm not going crazy 😅
Have you considered using hypothesis to property-test this library? It works particularly well for mathematical code with lots of clean invariants / symmetries. This property-test immediately finds a minimal counterexample:
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import ot
@given(
    arr1=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    arr2=st.lists(
        st.floats(min_value=0, max_value=1),
        min_size=1, max_size=10,
    ),
    delta=st.floats(min_value=-1, max_value=1),
)
def test_wasserstein_circle_is_rotationally_symmetric(arr1, arr2, delta):
    arr1 = np.array(arr1)
    arr2 = np.array(arr2)
    # Rotating both samples by the same delta should not change the distance
    [d1] = ot.wasserstein_circle(arr1, arr2)
    [d2] = ot.wasserstein_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
    assert np.isclose(d1, d2, atol=1e-3)
E Falsifying example: test_wasserstein_circle_is_rotationally_symmetric(
E arr1=[0.0],
E arr2=[0.125],
E delta=0.5,
E )
The binary_search_circle function passes this test 🙂
Actually not quite, this gives a warning:
E RuntimeWarning: invalid value encountered in divide
E Falsifying example: test_binary_search_circle_is_rotationally_symmetric(
E arr1=[0.5],
E arr2=[0.25],
E delta=0.5,
E )
Standalone example:
import numpy as np
import ot
import warnings
warnings.filterwarnings("error")
arr1 = np.array([0.5])
arr2 = np.array([0.25])
delta = 0.5
[d1] = ot.binary_search_circle(arr1, arr2)
[d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
assert np.isclose(d1, d2, atol=1e-3)
---------------------------------------------------------------------------
RuntimeWarning Traceback (most recent call last)
Cell In[16], line 9
6 delta = 0.5
8 [d1] = ot.binary_search_circle(arr1, arr2)
----> 9 [d2] = ot.binary_search_circle((arr1 + delta) % 1, (arr2 + delta) % 1)
11 assert np.isclose(d1, d2, atol=1e-3)
File /nix/store/nlrk6vd18px7gg88yjaxgn0l8dalpbcg-python3-3.12.7-env/lib/python3.12/site-packages/ot/lp/solver_1d.py:737, in binary_search_circle(u_values, v_values, u_weights, v_weights, p, Lm, Lp, tm, tp, eps, require_sort, log)
734 Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
736 mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
--> 737 tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
738 done[nx.prod(mask, axis=-1) > 0] = 1
739 elif nx.any(1 - done):
RuntimeWarning: invalid value encountered in divide
Not sure if this is a real bug but thought it was worth recording.
Sorry, it's my bad ^^. I shouldn't have used this implementation for the default behaviour...
I didn't know the hypothesis library. It seems very useful! I am reassured that binary_search_circle seems to pass the test...
Concerning binary_search_circle, thank you for noticing this little bug. I think this warning arises when dCptm - dCmtp == 0, but it is not a problem, since the corresponding values are not updated thanks to the condition mask_end > 0. However, it would be nice to avoid raising a warning in this case. I will see if I can fix this.
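For illustration, here is one way such a warning can be avoided in plain NumPy: compute the quotient only where the mask holds, rather than dividing everywhere and indexing afterwards. This is a sketch under assumptions, not the fix that went into POT: solver_1d.py works through the `nx` backend rather than NumPy directly, `masked_divide` is a hypothetical helper name, and the arrays are made-up stand-ins for the numerator and for dCptm - dCmtp. The 0.001 threshold is the one visible in the traceback above.

```python
import warnings

import numpy as np


def masked_divide(num, den, mask, fill=0.0):
    """Divide num by den only where mask is True; other entries get `fill`."""
    num = np.asarray(num, dtype=float)
    den = np.asarray(den, dtype=float)
    out = np.full(np.broadcast(num, den).shape, fill, dtype=float)
    # With `where`, the ufunc skips masked-out positions entirely,
    # so a 0 / 0 there never triggers "invalid value encountered".
    np.divide(num, den, out=out, where=mask)
    return out


num = np.array([0.0, 1.0])  # stand-in for the numerator
den = np.array([0.0, 2.0])  # stand-in for dCptm - dCmtp
mask = np.abs(den) > 0.001  # same threshold as mask_end in the traceback

with warnings.catch_warnings():
    warnings.simplefilter("error")  # any RuntimeWarning would raise here
    result = masked_divide(num, den, mask)

# entry 0 is left at the fill value, entry 1 is 1.0 / 2.0
print(result)
```

The masked entries are never read by the caller in the original code, so the choice of fill value should not matter; 0.0 is used here only to keep the output finite.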
Closing since this was fixed in #736