ott
Unbalanced FGW doesn't converge when margins are provided
Describe the bug
For an application use case, see the tests from moscot: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677
Unbalanced FGW is unstable, especially when marginals are provided. I tried various values of epsilon and tau, but it still doesn't converge. I think this started after https://github.com/ott-jax/ott/commit/41906a2a1ade19aa154189fabd7c159a160c9bf3
To Reproduce
import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve
# Generating random data for x and y
x = np.random.rand(96, 2) # 96 points in 2D
y = np.random.rand(96, 2) # Another 96 points in 2D
# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)
# a and b are vectors of ones with lengths matching the number of points in x and y, respectively
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])
# Call solve function with the specified parameters
solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
      fused_penalty=1.0, epsilon=1.0, a=a, b=b)
Hi @selmanozleyen, this seems to come from numerical imprecision; more specifically, the NaNs come directly from the initialization here, where marginal_1 is an array of all 0s. This leads to a transport mass of 0 and, later, to a NaN rescaling factor.
I will look into whether there's a more numerically stable way of computing this; however, simply using
a = jnp.ones(x.shape[0]) / x.shape[0]
b = jnp.ones(y.shape[0]) / y.shape[0]
solves the numerical precision issues.
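To make the failure mode concrete, here is a minimal NumPy sketch of what floating point does once the transport mass is exactly 0. The variable names mirror the comment above, but the thread does not show ott's actual rescaling formula, so the ratios below are hypothetical stand-ins. The sketch uses float32 because that is JAX's default precision, and it also shows one generic way a marginal can become exactly 0 (an underflowing exp); that part is an assumption, not a confirmed diagnosis of this bug.

```python
import numpy as np

# float32 mirrors JAX's default precision.
# HYPOTHETICAL: ott's rescaling formula is not shown in the thread; these ratios only
# illustrate what floating point arithmetic does once a mass is exactly 0.
marginal_1 = np.zeros(96, dtype=np.float32)  # all-zero marginal, as observed at init
transport_mass = marginal_1.sum()            # 0.0
target_mass = np.float32(96.0)               # total mass of a = jnp.ones(96)

with np.errstate(divide="ignore", invalid="ignore"):
    ratio = target_mass / transport_mass          # 96 / 0 -> inf
    rescaled = ratio * transport_mass             # inf * 0 -> nan
    self_ratio = transport_mass / transport_mass  # 0 / 0 -> nan

# One generic way to get exact zeros in float32 (an assumption, not a confirmed cause here):
# exp(-x) underflows to 0.0 once x exceeds roughly 104.
kernel_entry = np.exp(np.float32(-200.0))
print(ratio, rescaled, self_ratio, kernel_entry)
```

Once a single NaN appears in a rescaling factor, it propagates through every subsequent update, which is consistent with the solver never reporting convergence.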
@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1, it still fails in many cases; see the examples below. I'd expect unbalanced OT not to require the marginals to sum to 1.
a = np.ones(x.shape[0])*2
a[0:4] = 1
b = np.ones(y.shape[0])*2
b[0:4] = 1
# or
a = np.ones(x.shape[0])*2
b = np.ones(y.shape[0])*2
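For a sense of scale, a quick NumPy check of the two marginals above (the names a1/a2 are only for this sketch): their total masses are 188 and 192, versus 1 for the normalized vectors that work. The all-2s case, once normalized, is exactly ones(96) / 96, so it differs from the working setup only by a constant factor of 192.

```python
import numpy as np

n = 96
# First case from the comment above: mostly 2s, with the first four entries set to 1.
a1 = np.ones(n) * 2
a1[0:4] = 1
# Second case: all 2s.
a2 = np.ones(n) * 2

mass_a1, mass_a2 = a1.sum(), a2.sum()  # 188.0 and 192.0, versus 1.0 after normalizing
# Dividing by the total mass keeps the relative weights and makes each vector sum to 1.
a1_normalized = a1 / mass_a1
a2_normalized = a2 / mass_a2
# The all-2s case is exactly the normalized uniform vector scaled by 192.
print(mass_a1, mass_a2, np.allclose(a2_normalized, np.ones(n) / n))
```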
Thanks @selmanozleyen. I think what's happening here is a problem of scale. Although it may seem that dividing or multiplying a/b by a constant should have no bearing on the optimization, for entropic GW this is likely not the case because of the interplay with other parameters (notably epsilon, but more generally also the scale of the cost matrix), since the unbalanced problem adds a KL term.
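A minimal NumPy sketch of why the total mass is not a neutral choice, assuming a simplified unbalanced objective made of a quadratic GW term plus generalized-KL marginal penalties; ott's actual objective may differ (for example in how its KL and entropic terms are defined), and the entropic term is left out here entirely. Multiplying the coupling and the marginals by a constant c multiplies the GW term by c² but a generalized-KL term only by c, so the relative weight of transport cost versus marginal penalties changes with the mass. The helper names (gw_energy, generalized_kl, sq_dists) are hypothetical.

```python
import numpy as np


def gw_energy(cx, cy, p):
    """Quadratic GW term: sum_{i,j,k,l} (cx[i,k] - cy[j,l])**2 * p[i,j] * p[k,l]."""
    # diff[i, j, k, l] = cx[i, k] - cy[j, l]
    diff = cx[:, None, :, None] - cy[None, :, None, :]
    return np.einsum("ijkl,ij,kl->", diff**2, p, p)


def generalized_kl(p, q):
    """Generalized KL(p | q) = sum(p * log(p / q) - p + q) for positive, unnormalized p, q."""
    return np.sum(p * np.log(p / q) - p + q)


def sq_dists(z):
    """Pairwise squared Euclidean distances within one point cloud."""
    return ((z[:, None, :] - z[None, :, :]) ** 2).sum(-1)


rng = np.random.default_rng(0)
n = 8  # small n keeps the n**4 tensor in gw_energy cheap
cx = sq_dists(rng.random((n, 2)))
cy = sq_dists(rng.random((n, 2)))

a = np.full(n, 1.0 / n)  # normalized source marginal
p = rng.random((n, n))
p /= p.sum()             # some coupling with total mass 1
p_row = p.sum(axis=1)    # its first marginal, which differs from a

c = 96.0  # e.g. going from a = ones(96) / 96 to a = ones(96)
gw_ratio = gw_energy(cx, cy, c * p) / gw_energy(cx, cy, p)
kl_ratio = generalized_kl(c * p_row, c * a) / generalized_kl(p_row, a)
print(gw_ratio, kl_ratio)  # approximately c**2 = 9216 and c = 96
```

Under these assumptions, scaling the marginals by c is equivalent to solving the normalized problem with the GW term weighted c times more heavily relative to the marginal penalties, so epsilon and tau values tuned for normalized marginals need not carry over.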
Tangentially related: I think the converged flag in GW was bugged, as discussed in https://github.com/ott-jax/ott/pull/566