[WIP] New alpha API for OT solver (with pre-computed ground cost matrix)

Open rflamary opened this issue 1 year ago • 4 comments

Types of changes

I implemented the new POT API for general OT solvers. It introduces the function ot.solve, which can solve exact, regularized, and unbalanced OT depending on the parameters it receives, and returns an instance of the new OTResult class containing all the information that can be useful to the user (OT value, OT plan, OT marginals, and OT dual potentials).
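To make the idea of a single result object concrete, here is a minimal sketch in plain NumPy. The class name `OTResultSketch` and its fields are hypothetical and only mirror the quantities listed above; this is not POT's actual `OTResult` implementation. It stores the cost matrix only so that the linear value can be recomputed, and shows how the marginals follow from the plan.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


@dataclass
class OTResultSketch:
    """Hypothetical container mirroring the quantities listed above (not POT's OTResult)."""

    plan: np.ndarray  # transport plan P, shape (n, m)
    cost: np.ndarray  # ground cost matrix M, shape (n, m)
    value: Optional[float] = None  # full objective, may include reg/unbalanced terms
    potentials: Optional[Tuple[np.ndarray, np.ndarray]] = None  # dual potentials (alpha, beta)

    @property
    def value_linear(self) -> float:
        # Linear part of the objective: <P, M> = sum_ij P_ij * M_ij
        return float(np.sum(self.plan * self.cost))

    @property
    def marginals(self) -> Tuple[np.ndarray, np.ndarray]:
        # Row sums give the source marginal, column sums the target marginal
        return self.plan.sum(axis=1), self.plan.sum(axis=0)


# Usage: the independent coupling a b^T is feasible (though usually not optimal)
a = np.full(2, 0.5)
b = np.full(3, 1 / 3)
M = np.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0]])
res = OTResultSketch(plan=np.outer(a, b), cost=M)
print(res.value_linear, res.marginals)
```

A real solver would fill `value` and `potentials` itself; they are left optional here because the sketch does not solve anything.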

TODO:

  • [x] Implement base OTResult class
  • [x] Implement ot.solve function with exact OT and exact unbalanced OT
  • [x] Implement calls to all regularized solvers (Sinkhorn, L2)
  • [x] Add TV as type of unbalanced OT
  • [ ] Write the documentation for the solve function
  • [ ] Add all tests to ensure good code coverage

Motivation and context / Related issue

The API has been discussed with @agramfort, @jeanfeydy, @hichamjanati, and @ncourty, and aims to provide a general solver mechanism for the most common OT problems.

```python
import numpy as np
import ot

#%% Data

np.random.seed(42)

xs = np.random.randn(5, 2)
xt = np.random.randn(6, 2)

M = ot.dist(xs, xt)

a = ot.unif(5)
b = ot.unif(6)


# Solve exact OT
sol = ot.solve(M)

# Get the results
G = sol.plan  # OT plan
ot_loss = sol.value  # OT objective function value
ot_loss_linear = sol.value_linear  # OT value for the linear term np.sum(sol.plan * M)
alpha, beta = sol.potentials  # dual potentials

# Direct plan and loss computation
G = ot.solve(M).plan
ot_loss = ot.solve(M).value

# Exact OT with marginals a/b
sol2 = ot.solve(M, a, b)

# Regularized OT
sol_rkl = ot.solve(M, a, b, reg=1)  # KL regularization
sol_rentropy = ot.solve(M, a, b, reg=1, reg_type='entropy')  # entropic reg (Sinkhorn paper), only changes the loss value
sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2')  # L2 regularization


# Exact unbalanced OT with different penalizations
sol_utv = ot.solve(M, a, b, unbalanced=10, unbalanced_type='TV')
sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2')
sol_ukl = ot.solve(M, a, b, unbalanced=10, unbalanced_type='KL')


# Unbalanced and regularized OT
sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10)  # KL + KL
sol_rl2_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='L2')  # L2 + L2
sol_rkl_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='KL', unbalanced_type='L2')  # KL + L2
sol_rl2_ukl = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='KL')  # L2 + KL

sol_rentropy_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='entropy', unbalanced_type='L2')  # entropy + L2
```


Part of the code was written by @jeanfeydy in a HackMD file used during the discussion.
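In the example above, `reg_type='entropy'` is noted to change only the loss. The sketch below checks this numerically in plain NumPy rather than with POT: it runs a basic, unstabilized Sinkhorn loop, then evaluates the objective with two regularizers, assuming the entropic term is `sum(P * log P)` and the KL term is `KL(P | a b^T) = sum(P * log(P / (a_i b_j)))`; POT's exact conventions may differ. For any plan with marginals a and b, the two terms differ by the constant `-(sum(a log a) + sum(b log b))`, so the optimal plan is the same and only the reported value shifts.

```python
import numpy as np


def sinkhorn_plan(a, b, M, reg, n_iter=1000):
    """Basic Sinkhorn iterations for entropic OT (illustrative, no log-domain stabilization)."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns to marginal b exactly
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(42)
xs = rng.standard_normal((5, 2))
xt = rng.standard_normal((6, 2))
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
a = np.full(5, 1 / 5)
b = np.full(6, 1 / 6)
reg = 1.0

P = sinkhorn_plan(a, b, M, reg)
linear = np.sum(P * M)

entropy_term = np.sum(P * np.log(P))  # assumed 'entropy' regularizer
kl_term = np.sum(P * np.log(P / np.outer(a, b)))  # assumed 'KL' regularizer
value_entropy = linear + reg * entropy_term
value_kl = linear + reg * kl_term

# On the set of plans with marginals (a, b), the gap does not depend on P
constant = -(np.sum(a * np.log(a)) + np.sum(b * np.log(b)))
print(value_kl - value_entropy, reg * constant)  # both approximately log(30) here
```

Because the gap is constant over all feasible plans, both objectives share the same minimizer; only `value` differs, while `value_linear` is unchanged.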

How has this been tested (if it applies)

New tests cover all possible combinations of parameters.
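The combinations can be enumerated with `itertools.product`. The grid below is a hypothetical sketch (the strengths 1.0 and 10.0 and the deduplication rule are illustrative, not POT's actual test suite): it drops duplicates where a `reg_type` or `unbalanced_type` would be set without its strength. Each dict could then be passed as `ot.solve(M, a, b, **kwargs)` inside a parametrized test.

```python
from itertools import product

# Hypothetical grid; values are illustrative only
REGS = [None, 1.0]
REG_TYPES = ["KL", "entropy", "L2"]
UNBALANCED = [None, 10.0]
UNBALANCED_TYPES = ["KL", "L2", "TV"]


def solver_configs():
    """Yield distinct keyword-argument dicts for an ot.solve-like call."""
    seen = set()
    for reg, reg_type, unb, unb_type in product(REGS, REG_TYPES, UNBALANCED, UNBALANCED_TYPES):
        # A type only matters when its strength is set, so ignore it otherwise
        key = (reg, reg_type if reg is not None else None,
               unb, unb_type if unb is not None else None)
        if key in seen:
            continue
        seen.add(key)
        kwargs = {}
        if reg is not None:
            kwargs.update(reg=reg, reg_type=reg_type)
        if unb is not None:
            kwargs.update(unbalanced=unb, unbalanced_type=unb_type)
        yield kwargs


configs = list(solver_configs())
print(len(configs))  # 16 = (1 + 3 regularizations) x (1 + 3 unbalanced penalties)
```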

PR checklist

  • [x] I have read the CONTRIBUTING document.
  • [ ] The documentation is up-to-date with the changes I made (check build artifacts).
  • [x] All tests passed, and additional code has been covered with new tests.
  • [ ] I have added the PR and Issue fix to the RELEASES.md file.

rflamary • Jul 18 '22 09:07

Codecov Report

Merging #388 (06e19ed) into master (8490196) will increase coverage by 0.23%. The diff coverage is 97.95%.

Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master     #388      +/-   ##
==========================================
+ Coverage   94.02%   94.26%   +0.23%
==========================================
  Files          22       23       +1
  Lines        5924     6203     +279
==========================================
+ Hits         5570     5847     +277
- Misses        354      356       +2
```
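As a quick check of the figures above, line coverage can be taken as covered lines (hits) divided by tracked lines, assuming no partially covered lines (the report lists none). The snippet only re-derives numbers already shown in the table.

```python
# Figures copied from the coverage diff above
base_hits, base_lines = 5570, 5924  # master
pr_hits, pr_lines = 5847, 6203  # PR #388

base_cov = 100 * base_hits / base_lines
pr_cov = 100 * pr_hits / pr_lines

# Misses are tracked lines that were never executed
base_misses = base_lines - base_hits
pr_misses = pr_lines - pr_hits

print(f"master {base_cov:.2f}%  PR {pr_cov:.2f}%  misses {base_misses} -> {pr_misses}")
# prints: master 94.02%  PR 94.26%  misses 354 -> 356
```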

codecov[bot] • Jul 18 '22 12:07

Hi @rflamary,

Fantastic :-) I also started working on implementing this API last month (in the ot_api branch of GeomLoss), and this will be my project for the summer, alongside a clean benchmarking platform for OT solvers that follows the structure of ann-benchmarks.

I'd be happy to come and visit you in Saclay in September to synchronize all of this, and we can have a video call more or less anytime in August if you're not offline.

In any case, have a good summer and see you soon! Jean

jeanfeydy • Jul 18 '22 21:07

@jeanfeydy this is great! I think I will also work on this during the summer, and we definitely need to talk, especially since I changed the API a little. I don't really like using the same OTResult class both for traditional OT problems (value + plan + ...) and for OT barycenters, where the result is a distribution (masses + support positions). August is fine for a virtual meeting, and of course we should meet up in Saclay at the beginning of the academic year (at the end of September; the very beginning will be hectic for me).
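One way to picture that split: a plan-based result and a barycenter result carry different data, so they could be separate types. The sketch below is hypothetical; `PlanResult` and `BarycenterResult` are illustrative names, not POT API, and the fields only follow the description above (value + plan versus masses + support positions).

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PlanResult:
    """Hypothetical result of a classical OT problem: a value and a plan."""

    value: float  # objective value
    plan: np.ndarray  # transport plan, shape (n, m)


@dataclass
class BarycenterResult:
    """Hypothetical result of an OT barycenter problem: a distribution, not a plan."""

    masses: np.ndarray  # weight of each support point, shape (k,)
    support: np.ndarray  # positions of the support points, shape (k, d)

    def __post_init__(self):
        # The number of masses must match the number of support points
        if self.masses.shape[0] != self.support.shape[0]:
            raise ValueError("masses and support must have the same number of points")


plan_res = PlanResult(value=0.5, plan=np.eye(2) / 2)
bary = BarycenterResult(masses=np.full(3, 1 / 3), support=np.zeros((3, 2)))
```

Keeping the types apart means code that expects a plan never receives a distribution by accident, at the cost of a second class to document.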

rflamary • Jul 19 '22 06:07

Ok perfect, see you soon! (And I agree with the barycenter change, for sure.)

jeanfeydy • Jul 19 '22 09:07

I think it's OK now (the documentation for the function is done; new examples and updates to the existing examples will be done later, with the full API V2 release).

@agramfort, care for a quick code review?

rflamary • Dec 06 '22 11:12

It would be really cool to have this new API shown early on this page: https://pythonot.github.io/quickstart.html

agramfort • Dec 15 '22 08:12

I agree, but it is not ready yet: we also need to implement GW, OT on samples, and OT on grids, which is a lot of work in addition to the documentation...

I am going for a feature/bugfix 8.3 release shortly (we have many bugs in 8.2) where the new API is not yet promoted (or only with beta status), and then we will move toward POT 1.0 with a big documentation and example revamp.

rflamary • Dec 15 '22 08:12