
empirical_sinkhorn_divergence doesn't have a grad_fn

Open gabrielsantosrv opened this issue 3 years ago • 4 comments

Describe the bug

I am trying to use empirical_sinkhorn_divergence as a loss function in pytorch, but the returned tensor does not have a grad_fn, so the gradient can't be propagated.

Code sample

loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
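For reference, a minimal self-contained sketch that reproduces this (the tensor shapes and regularization value are illustrative):

import torch
import ot

# Illustrative point clouds; in the real experiment these come from the model.
source = torch.randn(64, 2, requires_grad=True)
target = torch.randn(64, 2)

loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)
print(loss.grad_fn)  # None, so loss.backward() cannot propagate gradients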

Expected behavior

Return a tensor with a grad_fn.

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.8.13
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version: 11.2
    • GPU models and configuration: Quadro RTX 8000

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.4.0-73-generic-x86_64-with-glibc2.17
Python 3.8.13 (default, Mar 28 2022, 11:38:47) 
[GCC 7.5.0]
NumPy 1.21.6
SciPy 1.8.1
POT 0.8.2

gabrielsantosrv avatar Aug 09 '22 20:08 gabrielsantosrv

Good catch! The backend handling for the sinkhorn divergence is not done properly, so the computation falls back to numpy and back, losing the gradient information (and a lot of its interest ;)).

We will look into it. In the meantime, I suggest that you write a function that uses ot.dist to compute the cost matrices and ot.sinkhorn2 to get differentiable losses, then sums the three terms. It can be done in a few lines of code.
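For instance, a minimal sketch along these lines (uniform sample weights assumed; the helper name is arbitrary):

import torch
import ot

def sinkhorn_divergence(x, y, reg):
    # Uniform weights on the empirical samples (adapt if your samples are weighted).
    a = torch.full((x.shape[0],), 1.0 / x.shape[0], dtype=x.dtype, device=x.device)
    b = torch.full((y.shape[0],), 1.0 / y.shape[0], dtype=y.dtype, device=y.device)

    # Pairwise squared Euclidean costs (ot.dist default metric).
    M_xy = ot.dist(x, y)
    M_xx = ot.dist(x, x)
    M_yy = ot.dist(y, y)

    # ot.sinkhorn2 returns differentiable entropic OT losses with the torch backend.
    loss_xy = ot.sinkhorn2(a, b, M_xy, reg)
    loss_xx = ot.sinkhorn2(a, a, M_xx, reg)
    loss_yy = ot.sinkhorn2(b, b, M_yy, reg)

    # Sinkhorn divergence: OT(x, y) - (OT(x, x) + OT(y, y)) / 2
    return loss_xy - 0.5 * (loss_xx + loss_yy)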

rflamary avatar Aug 10 '22 06:08 rflamary

Since I am using a GPU to run my experiments, do you think it is better to use ot.sinkhorn2 with the parameter method="sinkhorn_log" to compute the sinkhorn divergence?

gabrielsantosrv avatar Aug 10 '22 15:08 gabrielsantosrv

Yes, it should be more numerically stable (slightly slower than the standard sinkhorn though).
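With a workaround like the one sketched above, that just means passing the method argument to each ot.sinkhorn2 call, e.g.:

loss_xy = ot.sinkhorn2(a, b, M_xy, reg, method="sinkhorn_log")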

rflamary avatar Aug 10 '22 15:08 rflamary

That's ok, thank you :D

gabrielsantosrv avatar Aug 10 '22 15:08 gabrielsantosrv

Hello @gabrielsantosrv

The function should now preserve the gradients on the master branch (we added a test to check so that it does not happen again).

rflamary avatar Aug 18 '22 05:08 rflamary