[WIP] Implementation of FUGW and UCOOT

Open 6Ulm opened this issue 10 months ago • 4 comments

Types of changes

This PR is dedicated to the implementation of Fused Unbalanced GW and (Fused) Unbalanced COOT. Since the two share a similar structure, it is enough to write a common template and then a wrapper for each divergence. More precisely, we create a method called fused_unbalanced_cross_spaces_divergence, in which

  • reg_type="independent" corresponds to (Fused) UCOOT. This yields the fused_unbalanced_co_optimal_transport method.
  • reg_type="joint" corresponds to FUGW. This yields the fused_unbalanced_gromov_wasserstein method.
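
To make the template-plus-wrapper idea concrete, here is a minimal NumPy sketch. It is not the code of this PR: all function names are hypothetical, and it assumes that reg_type selects how the entropic term is built ("independent" as a sum of two KL terms, one per plan; "joint" as a single KL term on the product of the plans). That reading of reg_type is an assumption, not something stated above.

```python
import numpy as np


def gen_kl(p, q):
    """Generalized KL divergence between strictly positive arrays of equal shape."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q) - p + q))


def toy_cross_spaces_regularizer(P, Q, ref_P, ref_Q, reg_type):
    """Hypothetical common template: one code path, switched by reg_type."""
    if reg_type == "independent":
        # Assumed meaning: each plan is regularized on its own.
        return gen_kl(P, ref_P) + gen_kl(Q, ref_Q)
    if reg_type == "joint":
        # Assumed meaning: KL between the product measures, built explicitly
        # (fine for toy sizes, too large for real problems).
        return gen_kl(np.multiply.outer(P, Q), np.multiply.outer(ref_P, ref_Q))
    raise ValueError(f"unknown reg_type: {reg_type!r}")


def toy_ucoot_regularizer(P, Q, ref_P, ref_Q):
    """Hypothetical UCOOT-style wrapper: two plans, reg_type='independent'."""
    return toy_cross_spaces_regularizer(P, Q, ref_P, ref_Q, reg_type="independent")


def toy_fugw_regularizer(P, ref_P):
    """Hypothetical FUGW-style wrapper: a single plan used twice, reg_type='joint'."""
    return toy_cross_spaces_regularizer(P, P, ref_P, ref_P, reg_type="joint")
```

Each wrapper only fixes the arguments of the template, so any change to the shared logic is made in one place.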

We also allow for the unregularized approximation of FUGW and UCOOT, i.e. $\varepsilon = 0$, thanks to the Majorization-Minimization solver ot.unbalanced.mm_unbalanced and the L-BFGS-B solver ot.unbalanced.lbfgsb_unbalanced.

This implementation also allows for two types of marginal penalization: Kullback-Leibler divergence and squared L2 norm. We also allow the cost to be sub-differentiable w.r.t. the input matrices and reference distributions. This is implemented in fused_unbalanced_co_optimal_transport2 and fused_unbalanced_gromov_wasserstein methods.
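
As an illustration of the two marginal penalties, the following NumPy sketch measures how far the marginals of a transport plan are from the reference distributions. The function names and the divergence keys "kl" and "l2" are hypothetical, and the squared L2 term is written without any scaling factor; the implementation in this PR may use different conventions.

```python
import numpy as np


def kl_penalty(p, q):
    """Generalized KL(p | q) for nonnegative p and strictly positive q."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # convention: 0 * log(0) = 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])) - p.sum() + q.sum())


def l2_penalty(p, q):
    """Squared L2 distance between two vectors (no 1/2 factor in this sketch)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum((p - q) ** 2))


def marginal_penalty(plan, a, b, divergence="kl"):
    """Penalty on both marginals of `plan` against the references a and b."""
    div = {"kl": kl_penalty, "l2": l2_penalty}[divergence]
    plan = np.asarray(plan, dtype=float)
    return div(plan.sum(axis=1), a) + div(plan.sum(axis=0), b)


a = np.array([0.5, 0.5])
b = np.array([0.25, 0.25, 0.5])
balanced_plan = np.outer(a, b)   # marginals match a and b exactly
heavy_plan = 2.0 * balanced_plan  # twice the mass, so both marginals are off
```

A plan whose marginals match the references incurs no penalty under either divergence, while heavy_plan is penalized by both.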

Motivation and context / Related issue

How has this been tested (if it applies)

PR checklist

  • [x] I have read the CONTRIBUTING document.
  • [x] The documentation is up-to-date with the changes I made (check build artifacts).
  • [x] All tests passed, and additional code has been covered with new tests.
  • [x] I have added the PR and Issue fix to the RELEASES.md file.

6Ulm avatar Apr 04 '24 12:04 6Ulm

Codecov Report

Attention: Patch coverage is 97.66147% with 21 lines in your changes missing coverage. Please review.

Project coverage is 96.72%. Comparing base (24ad25c) to head (3d4bead).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #617      +/-   ##
==========================================
+ Coverage   96.67%   96.72%   +0.04%     
==========================================
  Files          85       88       +3     
  Lines       16956    17854     +898     
==========================================
+ Hits        16392    17269     +877     
- Misses        564      585      +21     

codecov[bot] avatar Jun 25 '24 22:06 codecov[bot]

We are not merging until the tests are back to reasonable time (17-20 min): 93 minutes is crazy!

Potential ways to make them shorter: use smaller problems and do not test the tf backend, which is crazy slow (see output).

rflamary avatar Jul 03 '24 06:07 rflamary

Indeed, thank you Rémi, I missed that! It seems that the tests already use a very small number of samples, so you could indeed consider skipping tf (and potentially jax if it is still too long). You should also control the convergence tolerance. I also believe that many test functions could be merged into one without increasing the number of solver calls, e.g. init_plans and init_duals can be covered in test_sanity by passing (None, the same init dual as the function's default) in one call and (the same init plan as the function's default, None) in another call.

cedricvincentcuaz avatar Jul 03 '24 09:07 cedricvincentcuaz

To clarify, test_sanity aims to check that FUGW(X, X) = 0 numerically and that we can recover the identity matrix as the OT plan. In this case, we only consider the unregularized case (so entropic regularization eps = 0).
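
The reasoning behind this sanity check can be sketched in NumPy. The sketch below only evaluates the quadratic Gromov-Wasserstein term with a squared loss; it does not call any solver from this PR. The helper name gw_square_loss is hypothetical, and uniform weights 1/n are assumed, so the "identity" plan is the identity matrix divided by n.

```python
import numpy as np


def gw_square_loss(D1, D2, plan):
    """sum over i,j,k,l of (D1[i,k] - D2[j,l])**2 * plan[i,j] * plan[k,l]."""
    # diff[i, j, k, l] = D1[i, k] - D2[j, l]
    diff = D1[:, None, :, None] - D2[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, plan, plan))


rng = np.random.default_rng(0)
n = 5
X = rng.normal(size=(n, 2))
# Pairwise Euclidean distances within X
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)

identity_plan = np.eye(n) / n                     # matches point i with itself
shifted_plan = np.roll(identity_plan, 1, axis=1)  # matches point i with point i+1

cost_identity = gw_square_loss(D, D, identity_plan)
cost_shifted = gw_square_loss(D, D, shifted_plan)
```

With the identity plan every term compares D[i, k] with itself, so the cost vanishes, and the marginals of the plan equal the uniform weights, so no marginal penalty is paid either. A plan that matches each point to a different one, such as shifted_plan, generally gives a strictly positive cost.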

By contrast, in test_init_duals/plans, we consider both regularized and unregularized cases. So, we can't really merge these tests into one.

To confirm Cédric's comment, the slowness is 100% due to the tensorflow backend. Will check whether jax is slow too.

6Ulm avatar Jul 08 '24 19:07 6Ulm