Duplicate Gaussian code
Duplicate code for estimating Gaussian moments and computing transport maps for Gaussian OT problems:
- In tools, `Gaussian`: https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/tools/gaussian_mixture/gaussian.py#L35 and https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/tools/gaussian_mixture/scale_tril.py#L175
- And in the ICNN initialisers: https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/core/icnn.py#L137, https://github.com/ott-jax/ott/blob/5657e786c99afb808edd98e061ec316eed0942be/ott/core/icnn.py#L161
It may be worth consolidating these. It may also be worth moving the Gaussian computations to /core, or creating a hierarchical structure that does not cause circular import errors (see e.g. https://github.com/ott-jax/ott/pull/98#discussion_r919287726).
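To make the consolidation concrete, here is a minimal sketch of one shared implementation of the two duplicated pieces (moment estimation and the Gaussian transport map). It assumes NumPy for brevity (ott itself is written in JAX, where `jax.numpy` offers the same calls), and the names `gaussian_moments`, `sqrtm_psd` and `gaussian_transport_map` are hypothetical, not part of ott's API:

```python
import numpy as np


def gaussian_moments(x):
    """Estimate the mean and (unbiased) covariance of samples x of shape (n, d)."""
    mean = x.mean(axis=0)
    centered = x - mean
    cov = centered.T @ centered / (x.shape[0] - 1)
    return mean, cov


def sqrtm_psd(a):
    """Return (A^{0.5}, A^{-0.5}) for a symmetric positive definite matrix A."""
    eigvals, eigvecs = np.linalg.eigh(a)
    root = np.sqrt(eigvals)  # assumes all eigenvalues are strictly positive
    sqrt_a = (eigvecs * root) @ eigvecs.T  # V diag(sqrt(l)) V^T
    inv_sqrt_a = (eigvecs / root) @ eigvecs.T  # V diag(1 / sqrt(l)) V^T
    return sqrt_a, inv_sqrt_a


def gaussian_transport_map(mean_a, cov_a, mean_b, cov_b):
    """Monge map from N(mean_a, cov_a) to N(mean_b, cov_b).

    Returns the symmetric matrix T = A^{-0.5} (A^{0.5} B A^{0.5})^{0.5} A^{-0.5}
    and a function applying x -> mean_b + T (x - mean_a) row-wise.
    """
    sqrt_a, inv_sqrt_a = sqrtm_psd(cov_a)
    middle, _ = sqrtm_psd(sqrt_a @ cov_b @ sqrt_a)
    t = inv_sqrt_a @ middle @ inv_sqrt_a

    def transport(x):
        # x has shape (n, d); multiplying by T.T applies T to each row
        return mean_b + (x - mean_a) @ t.T

    return t, transport
```

Having a single `sqrtm_psd` that returns both $A^{0.5}$ and $A^{-0.5}$ from one eigendecomposition is one way to avoid several near-identical square-root helpers; the linked ott code may use a different factorization.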
Thanks for spotting these, James, and yes, I agree: there are too many `sqrtm_only` in our code :)
If I may add another comment: I think the way we store Gaussians is not ideal, or at least deserves some debate. When doing Gaussian OT, we may need $A^{0.5}$, $A$ or $A^{-0.5}$. When computing $W_2^2$, either the first or the second is needed for each Gaussian (when computing from $A$ to $B$, $A^{0.5}$ is used; the other way around, $A$ is used), whereas to compute the map one similarly needs either (1, 3) or (2).
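For reference, with $\mu = \mathcal{N}(m_A, A)$ and $\nu = \mathcal{N}(m_B, B)$, the standard closed forms show which of the three matrices each quantity uses when going from $A$ to $B$:

$$
W_2^2(\mu, \nu) = \|m_A - m_B\|^2 + \operatorname{tr}(A) + \operatorname{tr}(B) - 2\operatorname{tr}\left(\left(A^{0.5} B A^{0.5}\right)^{0.5}\right), \qquad \operatorname{tr}(A) = \|A^{0.5}\|_F^2
$$

$$
T(x) = m_B + A^{-0.5}\left(A^{0.5} B A^{0.5}\right)^{0.5} A^{-0.5}\,(x - m_A)
$$

The source Gaussian therefore enters $W_2^2$ only through $A^{0.5}$ (1) and the map through $A^{0.5}$ and $A^{-0.5}$ (1, 3), while the target enters both only through $B$ (2).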
I am wondering whether, as a preprocessing step, we should store all three of them. I know we can recover (2) as the product (1)·(1), but that sounds a bit like overkill. So maybe storing the three triangular parts is not too complicated? Then redefining the Bures cost (or creating a new one, e.g. `BuresTril`) could make sense to save memory and compute.
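A minimal sketch of this proposal, assuming NumPy and entirely hypothetical names (`PrecomputedGaussian`, `pack_tril`, `unpack_tril`, `bures_sq`, `transport_matrix`; none of these exist in ott, and `BuresTril` above is only a suggestion). Because $A^{0.5}$ here is the symmetric square root, all three matrices are symmetric, so only their lower triangles ($d(d+1)/2$ values each) are stored, and all three come from a single eigendecomposition:

```python
from dataclasses import dataclass

import numpy as np


def pack_tril(a):
    """Keep only the lower-triangular entries of a symmetric d x d matrix."""
    rows, cols = np.tril_indices(a.shape[0])
    return a[rows, cols]


def unpack_tril(packed, d):
    """Rebuild the full symmetric matrix from its packed lower triangle."""
    rows, cols = np.tril_indices(d)
    out = np.zeros((d, d), dtype=packed.dtype)
    out[rows, cols] = packed
    out[cols, rows] = packed  # mirror into the upper triangle
    return out


@dataclass(frozen=True)
class PrecomputedGaussian:
    """Hypothetical container: mean plus packed A^{0.5} (1), A (2) and A^{-0.5} (3)."""
    mean: np.ndarray
    sqrt_tril: np.ndarray
    cov_tril: np.ndarray
    inv_sqrt_tril: np.ndarray

    @classmethod
    def from_moments(cls, mean, cov):
        # One eigendecomposition yields all three; assumes cov is positive definite.
        eigvals, eigvecs = np.linalg.eigh(cov)
        root = np.sqrt(eigvals)
        sqrt_cov = (eigvecs * root) @ eigvecs.T
        inv_sqrt_cov = (eigvecs / root) @ eigvecs.T
        return cls(mean, pack_tril(sqrt_cov), pack_tril(cov), pack_tril(inv_sqrt_cov))

    def unpacked(self, field):
        return unpack_tril(getattr(self, field), self.mean.shape[0])


def _sqrtm_psd(a):
    """Symmetric square root of a PSD matrix (tiny negative eigenvalues clipped)."""
    eigvals, eigvecs = np.linalg.eigh(a)
    return (eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))) @ eigvecs.T


def bures_sq(source, target):
    """W_2^2 from source to target, reading only (1) of source and (2) of target."""
    sqrt_a = source.unpacked("sqrt_tril")
    b = target.unpacked("cov_tril")
    cross = _sqrtm_psd(sqrt_a @ b @ sqrt_a)
    mean_term = np.sum((source.mean - target.mean) ** 2)
    # tr(A) = ||A^{0.5}||_F^2, so A itself is not unpacked here
    return mean_term + np.sum(sqrt_a ** 2) + np.trace(b) - 2.0 * np.trace(cross)


def transport_matrix(source, target):
    """T = A^{-0.5} (A^{0.5} B A^{0.5})^{0.5} A^{-0.5}: (1, 3) of source, (2) of target."""
    sqrt_a = source.unpacked("sqrt_tril")
    inv_sqrt_a = source.unpacked("inv_sqrt_tril")
    b = target.unpacked("cov_tril")
    return inv_sqrt_a @ _sqrtm_psd(sqrt_a @ b @ sqrt_a) @ inv_sqrt_a
```

The trade-off in this sketch is roughly $1.5\,d(d+1)$ stored values per Gaussian against $d^2$ for $A$ alone, in exchange for never recomputing a square root of a single Gaussian's covariance; the square root of the cross term $A^{0.5} B A^{0.5}$ still has to be computed per pair.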