
Instability in Sinkhorn-Knopp convergence

Open • sarahboufelja opened this issue 3 years ago • 1 comment

Describe the bug

I am using the Sinkhorn and Sinkhorn with Group Lasso implementations in the OT package to reproduce the results of the paper "Optimal Transport for Domain Adaptation" by Nicolas Courty et al. However, when I run the following code several times, I get inconsistent convergence results:

import ot
ot.da.SinkhornLpl1Transport(reg_e=10, reg_cl=1e0)
ot.da.SinkhornTransport(reg_e=100)

Sometimes the same code on the same data converges without errors, and sometimes the algorithm fails to converge. I did try different regularisation values, but I don't want to increase them significantly, as this would obviously lead to a uniform mapping. Is this a known convergence issue with the Sinkhorn implementation? How would you choose the right regularisation value? With cross-validation?
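One well-known numerical issue, offered here as a possible explanation and not as a diagnosis of this report: the textbook Sinkhorn-Knopp iterations work on the Gibbs kernel exp(-M/reg), which underflows to exactly zero in float64 once M/reg exceeds roughly 745. The scaling updates then divide by zero and the coupling becomes NaN. This failure is deterministic: for fixed inputs it happens on every run, so on its own it would explain failure at small regularisation but not run-to-run variation. The NumPy-only sketch below contrasts the kernel-domain iterations with the same fixed-point iterations written on log-domain dual potentials. The helper names (`sinkhorn_naive`, `sinkhorn_log`, `logsumexp`) are hypothetical and this is not POT's implementation. POT's `ot.sinkhorn` also takes a `method` argument for selecting stabilized variants; the available values depend on the installed version.

```python
import numpy as np


def sinkhorn_naive(a, b, M, reg, n_iter=1000):
    """Textbook Sinkhorn-Knopp scaling on the Gibbs kernel exp(-M / reg)."""
    K = np.exp(-M / reg)  # underflows to exactly 0.0 once M / reg exceeds ~745
    u, v = np.ones_like(a), np.ones_like(b)
    with np.errstate(divide="ignore", invalid="ignore"):
        for _ in range(n_iter):
            u = a / (K @ v)  # division by zero if a row of K is all zeros
            v = b / (K.T @ u)
        return u[:, None] * K * v[None, :]


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_log(a, b, M, reg, n_iter=1000):
    """Same fixed-point iterations on dual potentials f, g (u = exp(f / reg))."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = reg * (np.log(a) - logsumexp((g[None, :] - M) / reg, axis=1))
        g = reg * (np.log(b) - logsumexp((f[:, None] - M) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)


rng = np.random.default_rng(0)
a = np.full(5, 1 / 5)                   # uniform source weights
b = np.full(4, 1 / 4)                   # uniform target weights
M = rng.uniform(1.0, 2.0, size=(5, 4))  # toy cost matrix with entries in [1, 2]

P_naive = sinkhorn_naive(a, b, M, reg=1e-3)
P_log = sinkhorn_log(a, b, M, reg=1e-3)
print(np.isnan(P_naive).any())   # True: the kernel underflowed to all zeros
print(np.isfinite(P_log).all())  # True
```

The log-domain version never forms exp(-M/reg) directly, so it stays finite even when the regularisation is small relative to the cost. Convergence can still be slow at such values, which is a separate issue from numerical overflow.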

To Reproduce

Steps to reproduce the behavior:

  1. Extract the DeCAF features from the 6th and 7th layers of the pre-trained AlexNet, for both the MNIST and USPS data.
  2. Use ot.da.SinkhornLpl1Transport(reg_e=10, reg_cl=1e0) and ot.da.SinkhornTransport(reg_e=100) to compute the optimal mapping between MNIST and USPS.
  3. Vary the regularisation value from 1e-3 to 100.
  4. Run each experiment 10 times to assess the consistency of convergence.
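Steps 3 and 4 amount to a sweep over the regularisation value with repeated runs per value. The sketch below does this on fixed synthetic data, which is assumed here as a stand-in for the DeCAF features, using a hypothetical NumPy helper `sinkhorn_log` rather than POT. Because the iterations contain no randomness, repeated runs on identical inputs should give bit-identical couplings. If they do not in a real pipeline, the inputs themselves (for example features re-extracted or re-sampled between runs) are a natural first suspect. Non-convergence at small regularisation, by contrast, appears deterministically as a large marginal error.

```python
import numpy as np


def sinkhorn_log(a, b, M, reg, n_iter=200):
    """Log-domain Sinkhorn-Knopp iterations (hypothetical helper, not POT's code)."""
    def lse(x, axis):
        m = np.max(x, axis=axis, keepdims=True)
        return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = reg * (np.log(a) - lse((g[None, :] - M) / reg, axis=1))
        g = reg * (np.log(b) - lse((f[:, None] - M) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)


# Fixed synthetic data standing in for the source/target features (assumption).
rng = np.random.default_rng(42)
xs = rng.normal(size=(30, 5))
xt = rng.normal(loc=0.5, size=(40, 5))
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
a = np.full(len(xs), 1 / len(xs))
b = np.full(len(xt), 1 / len(xt))

results = {}
for reg in [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0]:
    runs = [sinkhorn_log(a, b, M, reg) for _ in range(10)]  # 10 repeats per value
    identical = all(np.array_equal(runs[0], P) for P in runs[1:])
    # The last update enforces the target marginal exactly, so the source
    # marginal error serves as the convergence diagnostic.
    row_err = np.abs(runs[0].sum(axis=1) - a).max()
    results[reg] = (identical, row_err)
    print(f"reg={reg:g}  identical={identical}  source-marginal error={row_err:.1e}")
```

Tracking a marginal error per value, instead of a pass/fail convergence flag, makes it easier to tell "not yet converged after n_iter iterations" apart from "inconsistent between runs".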

Screenshots

Code sample

Expected behavior

I expect the Sinkhorn algorithm to converge consistently to the optimal coupling.

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): macOS
  • Python version: 3.9
  • How was POT installed (source, pip, conda): pip3 install
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version:
    • GPU models and configuration:
    • Any other relevant information:

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Additional context

sarahboufelja, Dec 10 '21 21:12

Sorry for the long delay in replying. Yes, regularization values should be chosen with care, as they depend closely on the nature of the data (and the cost) at hand. Yet the behavior should be consistent across runs with exactly the same data and regularization parameters. If that is not the case, could you set up a small running example that does not require extra steps or data, which we could work on?
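To make the point about the cost concrete: in entropic OT the kernel is exp(-M/reg), so the coupling depends on the cost and the regularization only through the ratio M/reg. Rescaling the features therefore changes what a given reg value means. With features whose scale is set by the network, such as DeCAF activations, a fixed value like reg_e=10 can correspond to very different effective regularization strengths. The self-contained sketch below uses synthetic data and a hypothetical `sinkhorn_log` helper (not POT's implementation) to check this, and shows that normalizing the cost by its maximum makes a reg value scale-free. POT's `ot.da` transport classes take a `norm` argument (for example `norm="max"` or `norm="median"`) that rescales the cost before solving; check the documentation of the installed version.

```python
import numpy as np


def sinkhorn_log(a, b, M, reg, n_iter=300):
    """Log-domain Sinkhorn-Knopp iterations (hypothetical helper, not POT's code)."""
    def lse(x, axis):
        m = np.max(x, axis=axis, keepdims=True)
        return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = reg * (np.log(a) - lse((g[None, :] - M) / reg, axis=1))
        g = reg * (np.log(b) - lse((f[:, None] - M) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - M) / reg)


def sqeuclidean(x, y):
    """Pairwise squared Euclidean cost matrix."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


# Small synthetic problem: no external data needed.
rng = np.random.default_rng(0)
xs, xt = rng.normal(size=(20, 3)), rng.normal(size=(25, 3))
a, b = np.full(20, 1 / 20), np.full(25, 1 / 25)

M = sqeuclidean(xs, xt)
M_big = sqeuclidean(10 * xs, 10 * xt)  # features x10 => squared cost x100

P = sinkhorn_log(a, b, M, reg=1.0)
P_same_reg = sinkhorn_log(a, b, M_big, reg=1.0)      # same reg, 100x weaker in effect
P_scaled_reg = sinkhorn_log(a, b, M_big, reg=100.0)  # reg scaled with the cost

print(np.allclose(P, P_scaled_reg))  # True: only M / reg matters
print(np.allclose(P, P_same_reg))    # False

# Normalizing the cost by its maximum makes a given reg value scale-free.
P_norm = sinkhorn_log(a, b, M / M.max(), reg=0.05)
P_big_norm = sinkhorn_log(a, b, M_big / M_big.max(), reg=0.05)
print(np.allclose(P_norm, P_big_norm))  # True
```

With the cost normalized, a cross-validated reg value transfers between feature extractions of different scale, which a raw reg_e value does not.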

ncourty, Feb 01 '22 23:02