
Differentiation with `power=1.0` is unstable because the gradient of the distance breaks down when points are nearby

Open · marcocuturi opened this issue

Describe the bug

We originally chose to implement the squared Euclidean distance as `jnp.sum(x**2, axis=-1) + jnp.sum(y**2, axis=-1) - 2 * jnp.vdot(x, y)`. This works with `power=2.0`, which leaves the expression unchanged, but with `power=1.0` (taking its square root) differentiation fails when `x ≈ y`: the expanded form cancels catastrophically in floating point, and the derivative of the square root then runs into a 0/0 that is mishandled. This case needs custom differentiation, which is already implemented in `jnp.linalg.norm`.
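The custom differentiation mentioned above could look roughly like the following minimal sketch, which uses `jax.custom_jvp` with a guarded division. The name `safe_norm` is hypothetical (it is not part of ott), and returning a zero gradient when `x == y` is an assumption: zero is one valid subgradient of the norm at the origin, chosen here so that coincident points give 0 instead of nan.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_norm(z):
    # Euclidean norm over the last axis
    return jnp.sqrt(jnp.sum(z**2, axis=-1))


@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    (z,), (dz,) = primals, tangents
    n = safe_norm(z)
    # Replace a zero norm by 1 before dividing, then zero out the direction,
    # so that z == 0 yields a zero (sub)gradient instead of 0/0 = nan
    is_zero = n == 0.0
    safe_n = jnp.where(is_zero, 1.0, n)
    direction = jnp.where(is_zero[..., None], 0.0, z / safe_n[..., None])
    return n, jnp.sum(direction * dz, axis=-1)


# Gradient with respect to x of |x - y| for scalar inputs
grad_dist = jax.grad(lambda x, y: safe_norm(jnp.atleast_1d(x - y)))
print(grad_dist(1.001, 1.0))
print(grad_dist(1.0, 1.0))
```

The first call gives a gradient close to 1, and the second gives 0 rather than nan.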

To Reproduce

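```python
import jax
import jax.numpy as jnp

# Gradient with respect to x of |x - y|, via jnp.linalg.norm
g1 = jax.grad(lambda x, y: jnp.linalg.norm(x - y))
# Same distance, via the square root of the expanded squared distance
g2 = jax.grad(lambda x, y: (x**2 + y**2 - 2 * x * y) ** 0.5)
for eps in range(-3, -9, -1):
    print(g1(1. + 1 * 10**eps, 1.))
    print(g2(1. + 1 * 10**eps, 1.))
    print('--')
```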

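Expected behavior

We would expect both functions to behave similarly, with gradients close to 1.0 (the derivative of `|x - y|` with respect to `x` when `x > y`), but this is what I see: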

```
1.0
1.0240479
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
nan
nan
--
```
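The `g2` output is consistent with catastrophic cancellation in the expanded form: in float32 (JAX's default precision), `x**2 + y**2` and `2*x*y` are both close to 2, so their difference keeps only a few significant bits of the true squared distance, or none at all. A minimal sketch comparing the expanded form with squaring the difference directly, on the inputs of the first two rows (the helper names `sq_dist_expanded` and `sq_dist_direct` are hypothetical):

```python
import jax.numpy as jnp


def sq_dist_expanded(x, y):
    # Expanded form: ||x||^2 + ||y||^2 - 2 <x, y>; two large terms nearly cancel
    return jnp.sum(x**2, axis=-1) + jnp.sum(y**2, axis=-1) - 2 * jnp.vdot(x, y)


def sq_dist_direct(x, y):
    # Direct form: subtract first, then square, so no large terms cancel
    return jnp.sum((x - y) ** 2, axis=-1)


y = jnp.array([1.0], dtype=jnp.float32)
for delta in (1e-3, 1e-4):
    x = jnp.array([1.0 + delta], dtype=jnp.float32)
    print(delta, sq_dist_expanded(x, y), sq_dist_direct(x, y))
```

On standard IEEE float32 arithmetic, the expanded value for `1e-3` is off by several percent, which is in line with the 1.024 gradient in the first row, and for `1e-4` it rounds to exactly 0, where the derivative of the square root is undefined. The last row differs in kind: float32 rounds `1 + 1e-8` to exactly 1.0, so `x - y` is zero and even `jnp.linalg.norm` returns nan.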

As a result, I propose to rename the current (misnamed) `Euclidean` cost function to `SqEuclidean` and otherwise leave it as is, to make `power=1.0` the default for all `PointCloud` geometries, and to introduce a new `Euclidean` cost function that uses `jnp.linalg.norm`. This would also resolve the inconsistency of `power=2.0` being the current default, which is awkward for any cost other than `SqEuclidean`.
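A minimal sketch of the proposed split, written as plain JAX functions rather than ott cost classes; `sq_euclidean_cost` and `euclidean_cost` are hypothetical names, not ott's API. The squared cost keeps the expanded form, whose gradient with respect to `x` is `2x - 2y` and involves no square root, while the Euclidean cost differentiates through `jnp.linalg.norm` of the difference:

```python
import jax
import jax.numpy as jnp


def sq_euclidean_cost(x, y):
    # Expanded squared distance, kept unchanged; no square root, so its
    # gradient 2x - 2y is well defined even when x == y
    return jnp.sum(x**2, axis=-1) + jnp.sum(y**2, axis=-1) - 2 * jnp.vdot(x, y)


def euclidean_cost(x, y):
    # Proposed Euclidean cost: norm of the difference, avoiding cancellation
    return jnp.linalg.norm(x - y, axis=-1)


y = jnp.array([1.0, 2.0])
x = y + jnp.array([1e-5, 0.0])  # two nearby points
print(jax.grad(euclidean_cost)(x, y))
print(jax.grad(sq_euclidean_cost)(x, y))
```

For these nearby points, the Euclidean gradient is close to the unit vector `[1, 0]`, and the squared cost's gradient is close to `2 * (x - y)`.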

marcocuturi, Oct 18 '22 06:10