POT
[WIP] New alpha API for OT solver (with pre-computed ground cost matrix)
Types of changes
I implemented the new POT API for general OT solvers. It comes with the new function ot.solve,
which can be used to solve exact, regularized and unbalanced OT depending on the parameters it receives. It returns an instance of the new OTResult
class, which contains all the information that can be useful to the user (OT value, OT plan, OT marginals and OT dual potentials).
TODO :
- [x] Implement base OTResult class
- [x] Implement ot.solve function with exact OT and exact unbalanced OT
- [x] Implement call to all regularized solvers (Sinkhorn, L2)
- [x] Add TV as type of unbalanced OT
- [ ] Write the documentation for the solve function
- [ ] Add all tests to ensure good code coverage
Motivation and context / Related issue
The API has been discussed with @agramfort, @jeanfeydy, @hichamjanati and @ncourty, and aims to provide a general solver mechanism for the most common OT problems.
import numpy as np
import ot
#%% Data
np.random.seed(42)
xs = np.random.randn(5,2)
xt = np.random.randn(6,2)
M = ot.dist(xs,xt)
a = ot.unif(5)
b = ot.unif(6)
# Solve exact OT
sol = ot.solve(M)
# get the results
G = sol.plan # OT plan
ot_loss = sol.value # OT objective function value
ot_loss_linear = sol.value_linear # OT value for the linear term np.sum(sol.plan*M)
alpha, beta = sol.potentials # dual potentials
# direct plan and loss computation
G = ot.solve(M).plan
ot_loss = ot.solve(M).value
# OT exact with marginals a/b
sol2 = ot.solve(M, a, b)
# regularized OT
sol_rkl = ot.solve(M, a, b, reg=1) # KL regularization
sol_rentropy = ot.solve(M, a, b, reg=1, reg_type='entropy') # entropic reg (Sinkhorn paper), only changes the loss
sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2')
# Exact unbalanced OT with different penalizations
sol_utv = ot.solve(M, a, b, unbalanced=10, unbalanced_type='TV')
sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2')
sol_ukl = ot.solve(M, a, b, unbalanced=10, unbalanced_type='KL')
# Unbalanced and regularized OT
sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10) # KL + KL
sol_rl2_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='L2') # L2 + L2
sol_rkl_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='KL', unbalanced_type='L2') # KL + L2
sol_rl2_ukl = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='KL') # L2 + KL
sol_rentropy_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='entropy', unbalanced_type='L2') # entropy + L2
The code was written in part by @jeanfeydy in a HackMD file used during the discussion.
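To make the result fields concrete without installing this branch, here is a minimal NumPy-only sketch of what the plan, its marginals and the linear value represent. It does not call ot.solve: the independent coupling of a and b is used as a stand-in plan (feasible, but in general not optimal), and the helper name linear_cost is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(42)
xs = rng.standard_normal((5, 2))
xt = rng.standard_normal((6, 2))

# Squared Euclidean ground cost between all pairs of samples, shape (5, 6)
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)

# Uniform weights on source and target samples
a = np.full(5, 1 / 5)
b = np.full(6, 1 / 6)

# Stand-in plan (assumption): the independent coupling a b^T.
# It satisfies the marginal constraints but is generally not the optimal plan.
G = np.outer(a, b)


def linear_cost(plan, cost):
    """Linear part of the OT objective: sum_ij plan_ij * cost_ij."""
    return float(np.sum(plan * cost))


marginal_a = G.sum(axis=1)  # row sums, should match a
marginal_b = G.sum(axis=0)  # column sums, should match b
value_linear = linear_cost(G, M)
```

For an unregularized balanced problem the objective is exactly this linear term evaluated at the optimal plan; with a regularization or an unbalanced penalty, the objective also includes the penalty terms, which is why the example above exposes value and value_linear separately.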
How has this been tested (if it applies)
New tests cover all possible combinations of parameters.
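As a rough illustration of what testing a product of parameters can look like, the sketch below enumerates keyword-argument combinations with itertools.product. The values mirror the example in the description; the variable names are hypothetical, and the tests in this PR may be organized differently.

```python
from itertools import product

# Regularization settings: none, or reg=1.0 with each regularization type
reg_configs = [{}] + [
    {"reg": 1.0, "reg_type": t} for t in ("KL", "entropy", "L2")
]

# Unbalanced settings: none (balanced), or unbalanced=10.0 with each penalty type
unb_configs = [{}] + [
    {"unbalanced": 10.0, "unbalanced_type": t} for t in ("KL", "L2", "TV")
]

# Cartesian product of the two families of settings
cases = [{**reg, **unb} for reg, unb in product(reg_configs, unb_configs)]

for kwargs in cases:
    # In a real test, each case would be passed on as ot.solve(M, a, b, **kwargs)
    # and the returned plan, value and marginals would be checked.
    print(kwargs)
```

Whether every one of these combinations is supported by a solver is an assumption here; a test suite would mark unsupported pairs as expected failures rather than drop them silently.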
PR checklist
- [x] I have read the CONTRIBUTING document.
- [ ] The documentation is up-to-date with the changes I made (check build artifacts).
- [x] All tests passed, and additional code has been covered with new tests.
- [ ] I have added the PR and Issue fix to the RELEASES.md file.
Codecov Report
Merging #388 (06e19ed) into master (8490196) will increase coverage by 0.23%. The diff coverage is 97.95%.
Additional details and impacted files
@@ Coverage Diff @@
## master #388 +/- ##
==========================================
+ Coverage 94.02% 94.26% +0.23%
==========================================
Files 22 23 +1
Lines 5924 6203 +279
==========================================
+ Hits 5570 5847 +277
- Misses 354 356 +2
Hi @rflamary,
Fantastic :-) I have also started working on implementing this API last month (in the ot_api branch of GeomLoss), and this will be the project of the summer alongside a clean benchmarking platform for OT solvers that follows the structure of ann-benchmarks.
I'd be happy to come and visit you in Saclay in September to synchronize all of this, and we can have a video call more or less anytime in August if you're not offline.
In any case, have a good summer and see you soon! Jean
@jeanfeydy this is great. I think I will also work on this during the summer, and we definitely want to talk, especially since I changed the API a little bit and I don't really like having the same OTResult class for a traditional OT problem (value + plan + ...) and for an OT barycenter, where the result is a distribution (masses + support positions). August is fine for a virtual meeting, and of course we should meet up in Saclay at the beginning of the academic year (end of September; the beginning will be hectic for me).
Ok perfect, see you soon! (And I agree with the barycenter change, for sure.)
I think it's OK now (the doc for the function is done; the examples and example updates will be done later, when doing the full API V2 release).
@agramfort care for a quick code review?
It would be really cool to have this new API shown early on this page: https://pythonot.github.io/quickstart.html
I agree, but it is not ready yet because we also need to implement GW, OT on samples and OT on grids, which is a lot of work in addition to the doc...
I am going for a feature/bug 8.3 release shortly (we have many bugs in 8.2) where the new API is not yet promoted (or only with beta status), and then we go toward POT 1.0 with a big documentation and example revamp.