
Duplicate Gaussian code

Open JTT94 opened this issue 2 years ago • 1 comments

Duplicate code for estimating Gaussian moments and computing transport maps for Gaussian OT problems:

- In tools, Gaussian: https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/tools/gaussian_mixture/gaussian.py#L35 and https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/tools/gaussian_mixture/scale_tril.py#L175
- And in ICNN initialisers: https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/core/icnn.py#L137 and https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/core/icnn.py#L161

It may be worth consolidating these. It may also be worth moving Gaussian computations to /core, or creating a hierarchical structure that does not cause cyclical import errors (see e.g. https://github.com/ott-jax/ott/pull/98#discussion_r919287726)

JTT94, Aug 19 '22 10:08
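To make the consolidation idea concrete, here is a minimal sketch of a single shared moment estimator that both the tools code and the ICNN initialisers could call. It is written in plain NumPy for illustration; the name `gaussian_moments` and its signature are hypothetical and are not part of the ott API.

```python
import numpy as np


def gaussian_moments(points, weights=None):
    """Estimate the mean and covariance of a (weighted) point cloud.

    Hypothetical helper, not an ott function. `points` has shape (n, d);
    `weights` has shape (n,) and is normalised to sum to one.
    """
    points = np.asarray(points, dtype=float)
    n = points.shape[0]
    if weights is None:
        weights = np.full(n, 1.0 / n)
    else:
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()
    mean = weights @ points
    centered = points - mean
    # Weighted (biased) covariance: sum_i w_i (x_i - mean)(x_i - mean)^T.
    cov = (centered * weights[:, None]).T @ centered
    return mean, cov
```

Placing one such function in a module that has no dependency on either caller is one way to avoid the circular imports mentioned above.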

Thanks for spotting these James, and yes, I agree, there are too many `sqrtm_only` calls in our code :)

If I can add another comment, I think the way we store Gaussians is not ideal, or at least deserves some debate. When doing Gaussian OT, we might have to use (1) $A^{0.5}$, (2) $A$ or (3) $A^{-0.5}$. When computing $W_2^2$, either the first or the second is needed (if computing from $A$ to $B$ then $A^{0.5}$ is used; the other way around, $A$ is used), whereas to compute the map one needs, similarly, either (1, 3) or (2).

I am wondering whether, as a preprocessing step, we should store all 3 of them. I know we can recover (2) from (1*1), but that sounds a bit like overkill. So maybe storing the 3 triangular parts is not too complicated? Redefining the Bures cost (or creating a new one, e.g. BuresTril) could then make sense to save memory and compute.

marcocuturi, Aug 19 '22 15:08
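As a rough illustration of the precomputation idea in the comment above, the sketch below stores $A^{0.5}$, $A$ and $A^{-0.5}$ once and reuses them for both the squared Bures-Wasserstein distance and the Gaussian transport map from $A$ to $B$. It uses the standard closed forms $W_2^2 = \|m_A - m_B\|^2 + \mathrm{tr}\,A + \mathrm{tr}\,B - 2\,\mathrm{tr}\,(A^{0.5} B A^{0.5})^{0.5}$ and $T = A^{-0.5} (A^{0.5} B A^{0.5})^{0.5} A^{-0.5}$. Two assumptions to flag: it is plain NumPy rather than ott code, and it uses symmetric square roots from an eigendecomposition rather than the triangular factors suggested in the comment. All function names are hypothetical.

```python
import numpy as np


def sym_powers(cov):
    """Return (A^{1/2}, A, A^{-1/2}) for a symmetric positive-definite A."""
    w, v = np.linalg.eigh(cov)
    sqrt_cov = (v * np.sqrt(w)) @ v.T
    inv_sqrt_cov = (v / np.sqrt(w)) @ v.T
    return sqrt_cov, cov, inv_sqrt_cov


def sqrtm_psd(m):
    """Symmetric square root of a positive semi-definite matrix."""
    m = (m + m.T) / 2.0  # symmetrise against round-off
    w, v = np.linalg.eigh(m)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def bures_w2_squared(mean_a, powers_a, mean_b, cov_b):
    """Squared W2 distance between N(mean_a, A) and N(mean_b, B)."""
    sqrt_a, cov_a, _ = powers_a
    cross = sqrtm_psd(sqrt_a @ cov_b @ sqrt_a)
    return (np.sum((mean_a - mean_b) ** 2)
            + np.trace(cov_a) + np.trace(cov_b) - 2.0 * np.trace(cross))


def map_a_to_b(x, mean_a, powers_a, mean_b, cov_b):
    """Apply the OT map from N(mean_a, A) to N(mean_b, B) to rows of x."""
    sqrt_a, _, inv_sqrt_a = powers_a
    cross = sqrtm_psd(sqrt_a @ cov_b @ sqrt_a)
    t = inv_sqrt_a @ cross @ inv_sqrt_a  # symmetric, satisfies T A T = B
    return mean_b + (x - mean_a) @ t
```

With the three powers cached per Gaussian, each pairwise distance or map costs one extra matrix square root (of $A^{0.5} B A^{0.5}$) instead of recomputing the roots of $A$ every time.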