`rank2` lr initializer not reproducible between 0.4.3 and 0.4.4
Describe the bug
I have been trying to reproduce some results using the `rank2` LR initializer, without success. Here are the tests I ran:
To Reproduce
LRFGW
Data and modules
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein
from jax import config
import ott
print(ott.__version__)
config.update("jax_enable_x64", True)
def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):
    rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, d1))
    y = jax.random.uniform(rngs[1], (m, d2))
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (m,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    z = jax.random.uniform(rngs[4], (m, d1))
    return x, y, a, b, z
rng = jax.random.PRNGKey(0)
n, m, d1, d2 = 10_000, 15_000, 200, 300
x, y, a, b, z = create_points(rng, n, m, d1, d2)
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, z)
prob = quadratic_problem.QuadraticProblem(
    geom_xx,
    geom_yy,
    geom_xy=geom_xy,
    a=a,
    b=b,
    fused_penalty=1.0,
)
ott 0.4.3
init = "random"  # or "rank2"
solver = gromov_wasserstein.GromovWasserstein(rank=100, initializer=init)
ot_gwlr = solver(prob)
ott 0.4.4
from ott.solvers.quadratic import gromov_wasserstein_lr
init = "random"  # or "rank2"
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=100, initializer=init)
ot_gwlr = solver(prob)
The same happens in ott 0.4.5, and also for LR Sinkhorn (see below).
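For context on why `rank2` makes a good reproducibility check: unlike `"random"`, a rank-2 initialization of the kind sketched here draws nothing from a PRNG key, so two runs on the same problem should start from identical factors, and any divergence between versions has to come from the solver itself. Below is a minimal NumPy sketch of one such construction, Q = λ·x·g1ᵀ + (1 − λ)·y·g2ᵀ, where `x` and `g1` are fixed normalized ramps and `y`, `g2` absorb the remainder of the marginals. The helper name `rank2_factor`, the ramp vectors, the uniform inner marginal `g`, and the choice λ = ½·min(min a, min b, min g) are illustrative assumptions; ott's own `rank2` initializer may differ in these details.

```python
import numpy as np


def rank2_factor(marginal: np.ndarray, g: np.ndarray, lam: float) -> np.ndarray:
    """Hypothetical helper: nonnegative (n, r) factor with row sums `marginal`
    and column sums `g`, built as the sum of two rank-1 terms.

    Assumes sum(marginal) == sum(g) == 1 and 0 < lam <= min(marginal.min(), g.min()).
    """
    n, r = marginal.shape[0], g.shape[0]
    # Fixed ramps 1..n and 1..r, normalized to probability vectors.
    x = np.arange(1, n + 1, dtype=float)
    x /= x.sum()
    g1 = np.arange(1, r + 1, dtype=float)
    g1 /= g1.sum()
    # Remainders; nonnegative because lam <= the smallest entries and x, g1 <= 1.
    y = (marginal - lam * x) / (1.0 - lam)
    g2 = (g - lam * g1) / (1.0 - lam)
    return lam * np.outer(x, g1) + (1.0 - lam) * np.outer(y, g2)


gen = np.random.default_rng(0)
n, m, r = 50, 70, 5
a = gen.uniform(size=n)
a /= a.sum()
b = gen.uniform(size=m)
b /= b.sum()
g = np.full(r, 1.0 / r)  # uniform inner marginal (assumption)
lam = 0.5 * min(a.min(), b.min(), g.min())

Q = rank2_factor(a, g, lam)
R = rank2_factor(b, g, lam)
P = Q @ np.diag(1.0 / g) @ R.T  # low-rank coupling P = Q diag(1/g) R^T

# Same inputs give bit-identical factors: nothing random is involved.
assert np.array_equal(Q, rank2_factor(a, g, lam))
print(np.allclose(P.sum(axis=1), a), np.allclose(P.sum(axis=0), b))  # True True
```

Because Q·1 = a, Qᵀ·1 = g and likewise for R, the coupling P = Q diag(1/g) Rᵀ has row sums a and column sums b by construction.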
Hi @giovp, thanks for spotting this! I think it has more to do with the refactoring of LR Sinkhorn/GW; I will look into this and see what went wrong. In terms of quantitative results, do you see worse results in terms of the OT maps?
Thanks for taking a look, @michalk8!
In terms of quantitative results, do you see worse results in terms of the OT maps?
Yes: the costs and OT maps stay constant throughout the iterations.
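The symptom described here, a cost trace that never moves, can be checked mechanically. A minimal sketch, assuming the per-iteration costs are available as a 1-D array in which iterations that were never recorded are padded with -1; the helper `is_stalled` and that padding convention are assumptions, and with ott you would pass something like `np.asarray(ot_gwlr.costs)` only if your version's output exposes such a `costs` field.

```python
import numpy as np


def is_stalled(costs: np.ndarray, rtol: float = 1e-10) -> bool:
    """Hypothetical helper: True if every recorded cost equals the first one.

    Entries equal to -1 are treated as padding for iterations that were
    never recorded (an assumed convention) and are dropped first.
    """
    recorded = np.asarray(costs, dtype=float)
    recorded = recorded[recorded != -1.0]
    if recorded.size < 2:
        return False  # too little history to call it stalled
    return bool(np.allclose(recorded, recorded[0], rtol=rtol, atol=0.0))


print(is_stalled(np.array([3.2, 3.2, 3.2, -1.0, -1.0])))  # True: cost never moves
print(is_stalled(np.array([5.0, 4.1, 3.7, -1.0])))        # False: cost decreases
```

Comparing every entry against the first one means a trace that decreases and then flattens (ordinary convergence) is not flagged; only a trace that is flat from the very first recorded iteration, the pattern reported here, is.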
I think this can be closed, as it was fixed in #494.