
Demo of multidimensional Wasserstein distance

Open • lmmx opened this issue 1 year ago • 2 comments

Hi, I've just come across your library while studying Wasserstein distances and OT, and was hoping to compute some simple examples like the one shown in Lilian Weng's blog post:

  • P = [3, 2, 1, 4]
  • Q = [1, 2, 4, 3]
  • EMD = 5


However, after adapting the domain adaptation (image recolouring) example, I ended up with a cost_ attribute in my ot_emd object filled with 14.0 rather than 5, and I could not see anywhere else the distance might be hiding.

import numpy as np
import ot

# each "image" is 1 x 1 pixel with 4 channels
im1 = np.array([[[3,2,1,4]]]).astype(float)
im2 = np.array([[[1,2,4,3]]]).astype(float)

X1 = im1.reshape(im1.shape[0] * im1.shape[1], -1)
X2 = im2.reshape(im2.shape[0] * im2.shape[1], -1)

rng = np.random.RandomState(42)
n_samples = 2

idx1 = rng.randint(X1.shape[0], size=(n_samples,))
idx2 = rng.randint(X2.shape[0], size=(n_samples,))

Xs = X1[idx1, :]
Xt = X2[idx2, :]

ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)

transp_Xs_emd = ot_emd.transform(Xs=X1)
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)

im1t = np.clip(transp_Xs_emd.reshape(im1.shape), 0, 1)
im2t = np.clip(transp_Xt_emd.reshape(im2.shape), 0, 1)

This gives

>>> ot_emd.cost_
array([[14., 14.],
       [14., 14.]])
>>> ot_emd.coupling_
array([[0. , 0.5],
       [0.5, 0. ]])
>>> ot_emd.mu_s
array([0.5, 0.5])
>>> ot_emd.mu_t
array([0.5, 0.5])

Am I close to the correct usage here, or am I (presumably) misunderstanding something or overlooking something in the docs? I looked through the intro and the general guide as well as the API docs. Calling emd2 directly gives 0:

>>> ot.emd2([.3,.2,.1,.4], [.1,.2,.4,.3], np.flipud(np.eye(4)))
0.0

I feel that an example simple enough to serve as a toy problem should be easy to compute, so please excuse me for asking such a basic question!

lmmx • Aug 22 '22 01:08

For me, the following code returns a Wasserstein distance of 5.0. The key point is that the cost matrix M should be computed from the distance between indices (e.g., moving mass from index 0 to index 2 costs 2, and from index 1 to index 0 costs 1). Please tell me if I'm wrong :)

import numpy as np
import ot
from scipy.stats import wasserstein_distance as wasser

A = np.array([3, 2, 1, 4])
B = np.array([1, 2, 4, 3])

# cost of moving mass from index i in A to index j in B
ia = np.arange(0, A.shape[0], 1)
ib = np.arange(0, B.shape[0], 1)
n = A.shape[0]
M = ot.dist(ia.reshape((n, 1)), ib.reshape((n, 1)), 'euclidean')

# compute wasserstein distance (ot & scipy)
W = ot.emd2(A, B, M)
Wsci = wasser(A, B)
print(M)  # prints [[0. 1. 2. 3.], [1. 0. 1. 2.], [2. 1. 0. 1.], [3. 2. 1. 0.]]
print(W, Wsci)  # prints 5.0 0.0
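As an independent check of the 5.0: assuming unit-spaced bins, an |i - j| ground cost and equal total mass in both histograms, the 1D EMD reduces to the sum of absolute differences between the cumulative sums of the two histograms (the standard closed form for 1D W1, not anything specific to POT). A small sketch:

```python
import numpy as np

A = np.array([3, 2, 1, 4])
B = np.array([1, 2, 4, 3])

# With unit-spaced bins, cost |i - j| and equal total mass,
# W1 = sum over bins of |CDF_A - CDF_B|
cdf_gap = np.abs(np.cumsum(A) - np.cumsum(B))
print(cdf_gap)        # [2 2 1 0]
print(cdf_gap.sum())  # 5
```

This matches ot.emd2 with the index-distance matrix M, and it also shows why the original emd2 call with np.flipud(np.eye(4)) could not work: that matrix does not encode the distance between bins.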

hynkis • Aug 28 '22 08:08

Hello,

These two calls do not compute the same thing.

In ot.emd2, A and B are the weights of the samples and M is the cost matrix (here between samples located at integer positions).

To compute the same quantity with SciPy, you need to call

Wsci = wasser(ia, ib, u_weights=A, v_weights=B)

In ot.emd2, the positions of the samples (and hence the metric) are encoded in M and the weights in A and B. In the SciPy function, which works only in 1D, the positions are passed as the values u and v. For 1D problems, POT also provides a much faster function: https://pythonot.github.io/all.html?highlight=wasserstein_1d#ot.wasserstein_1d

rflamary • Aug 29 '22 06:08