`power=1.0` differentiation is unstable: differentiating the distance is ill-behaved when points are nearby
**Describe the bug**
We originally chose to implement the squared Euclidean distance as `jnp.sum(x**2, axis=-1) + jnp.sum(y**2, axis=-1) - 2 * jnp.vdot(x, y)`. Although this works with `power=2.0` (which leaves it unchanged), it fails to differentiate gracefully when `x` is nearly equal to `y` and `power=1.0`, because the gradient hits a 0/0. Some custom differentiation is needed in that case, and it is already implemented in `jnp.linalg.norm`.
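For reference, a minimal sketch of the kind of custom differentiation meant here, assuming a standalone helper; the name `safe_norm` and the choice of subgradient at zero are illustrative, not ott's API:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_norm(x):
    # Plain Euclidean norm of the difference vector.
    return jnp.sqrt(jnp.sum(x ** 2))

@safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    norm = jnp.sqrt(jnp.sum(x ** 2))
    # The true directional derivative is <x, dx> / ||x||, a 0/0 at x == 0;
    # pick the subgradient 0 there instead of letting nan propagate.
    denom = jnp.where(norm == 0., 1., norm)
    return norm, jnp.where(norm == 0., 0., jnp.vdot(x, dx) / denom)
```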
**To Reproduce**
```python
import jax
import jax.numpy as jnp

# Gradient of the norm-based distance vs. the expanded squared form.
g1 = jax.grad(lambda x, y: jnp.linalg.norm(x - y))
g2 = jax.grad(lambda x, y: (x**2 + y**2 - 2 * x * y) ** 0.5)
for eps in range(-3, -9, -1):
    print(g1(1. + 10. ** eps, 1.))
    print(g2(1. + 10. ** eps, 1.))
    print('--')
```
**Expected behavior**
Here we would have expected similar behaviour, but what I see is:
```
1.0
1.0240479
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
1.0
nan
--
nan
nan
--
```
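For what it's worth, the nan pattern matches a 0/0 in the chain rule: the derivative of the expanded form is `(2x - 2y) / (2 * (x**2 + y**2 - 2*x*y) ** 0.5)`, and as soon as the radicand cancels to zero (or goes slightly negative) in float32, here from `eps = -4` onwards, the quotient evaluates to nan.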
As a result, I propose to rename the current (misnamed) `Euclidean` cost function to `SqEuclidean` and leave it as it is, to default to `power=1.0` for all `PointCloud` geometries, and to introduce a `Euclidean` cost function that uses `jnp.linalg.norm`. This will also clean up the contradiction that `power=2.0` is the current default, which is messy for anything that is not `SqEuclidean`. A rough sketch of the split follows below.
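Concretely, something along these lines; the class and method names here are illustrative stand-ins, not ott's actual cost-function interface:

```python
import jax.numpy as jnp

class SqEuclidean:
    """The current cost, renamed: squared Euclidean distance."""
    def pairwise(self, x, y):
        # Expanded form, unchanged from the current implementation.
        return jnp.sum(x ** 2, axis=-1) + jnp.sum(y ** 2, axis=-1) - 2. * jnp.vdot(x, y)

class Euclidean:
    """Proposed new cost: plain Euclidean distance via jnp.linalg.norm,
    which already handles differentiation when x == y."""
    def pairwise(self, x, y):
        return jnp.linalg.norm(x - y)
```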