
`rank2` lr initializer not reproducible between 0.4.3 and 0.4.4

Open giovp opened this issue 2 years ago • 2 comments

Describe the bug

I have been trying to reproduce some results using the `rank2` LR initializer, without success. Here are the tests I ran.

To Reproduce

LRFGW

Data and modules

```python
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
import numpy as np

from ott.geometry import pointcloud
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein
from jax import config

import ott
print(ott.__version__)

config.update("jax_enable_x64", True)

def create_points(rng: jax.Array, n: int, m: int, d1: int, d2: int):
    rngs = jax.random.split(rng, 5)
    x = jax.random.uniform(rngs[0], (n, d1))
    y = jax.random.uniform(rngs[1], (m, d2))
    a = jax.random.uniform(rngs[2], (n,))
    b = jax.random.uniform(rngs[3], (m,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    z = jax.random.uniform(rngs[4], (m, d1))
    return x, y, a, b, z


rng = jax.random.PRNGKey(0)
n, m, d1, d2 = 10_000, 15_000, 200, 300
x, y, a, b, z = create_points(rng, n, m, d1, d2)


geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, z)
prob = quadratic_problem.QuadraticProblem(
    geom_xx,
    geom_yy,
    geom_xy=geom_xy,
    a=a,
    b=b,
    fused_penalty=1.0,
)
```

ott 0.4.3

```python
init = "random"  # or "rank2"
solver = gromov_wasserstein.GromovWasserstein(rank=100, initializer=init)
ot_gwlr = solver(prob)
```

ott 0.4.4

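```python
# The low-rank GW solver lives in its own module from 0.4.4 on.
from ott.solvers.quadratic import gromov_wasserstein_lr

init = "random"  # or "rank2"
solver = gromov_wasserstein_lr.LRGromovWasserstein(rank=100, initializer=init)
ot_gwlr = solver(prob)
```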
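To quantify the difference between versions rather than comparing plots by eye, one option is to save the per-iteration costs from each environment and compare them numerically. This is a minimal sketch, not part of ott: `compare_cost_traces` is a hypothetical helper, and it assumes (a) the solver output exposes a 1-D `costs` array and (b) iterations that were not run are padded with -1. Both points should be checked against the installed ott version.

```python
import numpy as np


def compare_cost_traces(costs_a, costs_b, rtol=1e-6, atol=1e-8):
    """Compare two per-iteration cost traces (hypothetical helper).

    Entries equal to -1 are assumed to be padding for iterations that
    were not run and are dropped before comparing.
    Returns (match, max_abs_diff).
    """
    a = np.asarray(costs_a, dtype=float)
    b = np.asarray(costs_b, dtype=float)
    # Keep only the iterations that were actually recorded.
    a = a[a != -1]
    b = b[b != -1]
    if a.shape != b.shape:
        # A different number of recorded iterations already means the runs differ.
        return False, float("inf")
    max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
    return bool(np.allclose(a, b, rtol=rtol, atol=atol)), max_diff
```

For example, after running `np.save("costs_0.4.3.npy", np.asarray(ot_gwlr.costs))` in one environment and the equivalent in the other (file names hypothetical), call `compare_cost_traces(np.load("costs_0.4.3.npy"), np.load("costs_0.4.4.npy"))`. The same `rng` and inputs must be used in both runs for the comparison to be meaningful.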

The same happens in ott 0.4.5, and also for LR Sinkhorn.

giovp, Jan 10 '24 12:01

Hi @giovp, thanks for spotting this! I think it has more to do with the refactoring of LR Sinkhorn/GW; I will look into this and see what went wrong. In terms of quantitative results, do you see worse results in terms of the OT maps?

michalk8, Jan 11 '24 16:01

Thanks for taking a look, @michalk8!

> In terms of quantitative results, do you see worse results in terms of the OT maps?

Yes: the costs and the OT maps stay constant throughout the iterations.
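A quick way to check for this symptom (a cost trace that never changes) is to test whether all recorded costs are numerically identical. This is a minimal sketch: `is_stalled` is a hypothetical helper, and it again assumes that the output's `costs` array pads iterations that were not run with -1, which should be verified for the installed ott version.

```python
import numpy as np


def is_stalled(costs, rtol=1e-10, min_iters=2):
    """Return True if every recorded cost equals the first one (hypothetical helper).

    `costs` is a 1-D per-iteration cost trace; entries equal to -1 are
    treated as padding for iterations that were not run (an assumption).
    """
    c = np.asarray(costs, dtype=float)
    c = c[c != -1]
    if c.size < min_iters:
        # Too few recorded iterations to call the trace flat.
        return False
    # Flat trace: all costs match the first one up to a relative tolerance.
    return bool(np.allclose(c, c[0], rtol=rtol, atol=0.0))
```

Usage would look like `is_stalled(np.asarray(ot_gwlr.costs))`, run once with `init = "random"` and once with `init = "rank2"`. Since costs may only be recorded every few inner iterations, a short trace can look flat simply because few entries were stored; `min_iters` guards against the most trivial case.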

giovp, Jan 12 '24 08:01

I think this can be closed, as it was fixed in #494.

giovp, Jul 09 '24 07:07