
Unbalanced FGW doesn't converge when margins are provided

selmanozleyen opened this issue 10 months ago

Describe the bug

For an application use case, see the tests from moscot: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677

Unbalanced FGW is unstable, especially when margins are provided. I played with epsilon and the taus, but it still doesn't converge. I think this started after https://github.com/ott-jax/ott/commit/41906a2a1ade19aa154189fabd7c159a160c9bf3

To Reproduce

import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve


# Generating random data for x and y
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # Another 96 points in 2D

# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are unnormalized marginals of ones, with lengths matching the
# number of points in x and y; note they do not sum to 1
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

# Call solve function with the specified parameters
solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
      fused_penalty=1.0, epsilon=1.0, a=a, b=b)

selmanozleyen avatar Apr 17 '24 09:04 selmanozleyen

Hi @selmanozleyen, this seems to come from numerical imprecision. More specifically, the NaNs come directly from the initialization here, where marginal_1 is an array of all zeros (leading to a transport mass of 0), which in turn makes the rescaling factor NaN. I will look into whether there's a more numerically stable way of computing this; however, simply using

a = jnp.ones(x.shape[0]) / x.shape[0]
b = jnp.ones(y.shape[0]) / y.shape[0]

resolves the numerical precision issues.
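A minimal sketch of the suggested workaround (pure numpy, without calling ott itself): rescale both marginals to probability vectors before passing them to solve.

```python
import numpy as np

# Same setup as in the reproducer: 96 random points in 2D on each side.
x = np.random.rand(96, 2)
y = np.random.rand(96, 2)

# Dividing the all-ones marginals by the number of points makes each
# marginal a probability vector, avoiding the all-zero marginal_1 and
# the resulting NaN rescaling factor at initialization.
a = np.ones(x.shape[0]) / x.shape[0]
b = np.ones(y.shape[0]) / y.shape[0]

print(a.sum(), b.sum())  # both sum to 1
```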

michalk8 avatar Apr 24 '24 09:04 michalk8

@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1, it still fails in many cases; for example, see the cases below. I'd expect unbalanced OT not to require that the marginals sum to 1.

a = np.ones(x.shape[0])*2
a[0:4] = 1
b = np.ones(y.shape[0])*2
b[0:4] = 1
# or 
a = np.ones(x.shape[0])*2
b = np.ones(y.shape[0])*2
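A hedged workaround sketch for the non-uniform case above (pure numpy, no ott call): dividing by the total mass keeps the relative weights intact while handing the solver a marginal that sums to 1; the original total mass can be kept aside and reapplied afterward if needed.

```python
import numpy as np

# Non-uniform marginal from the example: first 4 points carry half the
# weight of the remaining ones.
a = np.ones(96) * 2.0
a[:4] = 1.0

total_mass = a.sum()        # original scale, can be folded back later
a_prob = a / total_mass     # sums to 1

# The rescaling does not change relative weights between points.
ratio_before = a[0] / a[10]          # 0.5
ratio_after = a_prob[0] / a_prob[10]
print(ratio_before, ratio_after)
```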

selmanozleyen avatar May 14 '24 08:05 selmanozleyen

Thanks @selmanozleyen. I think what's happening here is a problem of scales. Although it may seem that dividing/multiplying a/b by a constant should have no bearing on the optimization, in the case of entropic GW this is likely not the case because of the interplay with other parameters (notably epsilon, but also more generally the scale of the cost matrix), since the unbalanced problem adds a KL term.
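One way to make the scale interplay concrete is through the generalized (unnormalized) KL divergence commonly used in unbalanced OT penalties: multiplying both measures by a constant c multiplies the penalty by c, so rescaling the marginals effectively rescales the marginal-relaxation term relative to epsilon. A small numpy sketch (an illustration, not ott code):

```python
import numpy as np

def gen_kl(p, q):
    # Generalized KL divergence for unnormalized measures:
    # KL(p || q) = sum p * log(p / q) - sum p + sum q
    return np.sum(p * np.log(p / q)) - p.sum() + q.sum()

p = np.array([1.0, 2.0, 3.0])
q = np.array([2.0, 2.0, 2.0])

c = 5.0
# Scaling both measures by c scales the penalty linearly:
# KL(c*p || c*q) = c * KL(p || q), so the balance against the
# epsilon-scaled entropy term shifts with the marginals' mass.
lhs = gen_kl(c * p, c * q)
rhs = c * gen_kl(p, q)
print(lhs, rhs)
```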

Tangentially related: I think the converged flag in GW was buggy, as discussed in https://github.com/ott-jax/ott/pull/566

marcocuturi avatar Jul 30 '24 03:07 marcocuturi