
`sinkhorn2` and its `functorch.vmap` compatibility

Open hmdolatabadi opened this issue 1 year ago • 3 comments

🚀 Feature

Making the ot.sinkhorn2 function compatible with functorch.vmap.

Motivation

I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground-truth value. Until now I have been using a for-loop:

import ot

loss = 0
for i in range(len(P_batch)):
    # accumulate the Sinkhorn loss of sample i against its ground truth (each sample counted once)
    loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)

but this is far too slow for my application. From reading the functorch documentation, it looks like vmap should let me vectorize this computation over the batch dimension:

losses = vmap(ot.sinkhorn2)(P, Q, C, epsilon)

But after wrapping my function in vmap, I get this weird error:

File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502 v = b / KtransposeU
    503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
    506         or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507         or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508     # we have reached the machine precision
    509     # come back to previous solution and quit loop
    510     warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511     u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
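
For context, the error above is what vmap raises whenever a Python `if` has to turn a batched tensor into a single True/False. The following is a minimal sketch, not POT code: `reciprocal_branching` and `reciprocal_masked` are hypothetical helpers that mimic the shape of the check in `sinkhorn_knopp`, and `torch.func.vmap` (PyTorch 2.x) is assumed in place of the older `functorch.vmap` import.

```python
import torch
from torch.func import vmap  # assumption: PyTorch >= 2.0; older setups use `from functorch import vmap`


def reciprocal_branching(x):
    # Python `if` on a tensor value: under vmap this needs a single bool
    # for the whole batch, which is the data-dependent control flow vmap rejects.
    if torch.any(x == 0):
        return torch.zeros_like(x)
    return 1.0 / x


def reciprocal_masked(x):
    # Same per-sample result, but the condition stays a tensor and the
    # selection is done with torch.where instead of a Python branch.
    bad = torch.any(x == 0)
    return torch.where(bad, torch.zeros_like(x), 1.0 / x)


x = torch.tensor([[1.0, 2.0], [0.0, 4.0]])

try:
    vmap(reciprocal_branching)(x)
    raised = False
except RuntimeError as err:
    raised = True
    print("vmap rejected the branching version:", str(err)[:60])

out = vmap(reciprocal_masked)(x)
print(out)
```

Note that `torch.where` evaluates both branches for every sample, so this replaces only the selection, not the cost of computing the discarded branch.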

Pitch

It appears that the data-dependent if-statement in sinkhorn_knopp (the numerical-error check shown in the traceback) needs to be replaced with a vmap-compatible alternative. Any help is appreciated.
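
One possible direction, sketched under stated assumptions: a plain PyTorch Sinkhorn loop with a fixed number of iterations and no tensor-valued `if`, so that vmap can batch it. This is not POT's implementation. `sinkhorn_loss_fixed_iters` and `n_iters` are hypothetical names, the numerical-error check and the stopping criterion from `sinkhorn_knopp` are simply dropped, and the returned value is the transport cost `<P, C>` of the approximate plan. `torch.func.vmap` (PyTorch 2.x) is assumed.

```python
import torch
from torch.func import vmap  # assumption: PyTorch >= 2.0; older setups use `from functorch import vmap`


def sinkhorn_loss_fixed_iters(a, b, C, reg, n_iters=200):
    """Entropic OT cost <P, C> for one pair of histograms.

    Hypothetical helper: runs a fixed number of Sinkhorn iterations and
    never branches on a tensor value, so it can be wrapped in vmap.
    """
    K = torch.exp(-C / reg)          # Gibbs kernel, shape (n, m)
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iters):         # fixed trip count, no early exit
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]  # approximate transport plan
    return torch.sum(P * C)


torch.manual_seed(0)
B, n = 4, 5
P_batch = torch.softmax(torch.randn(B, n), dim=1)   # toy batch of histograms
Q_batch = torch.softmax(torch.randn(B, n), dim=1)
grid = torch.linspace(0, 1, n)
C = (grid[:, None] - grid[None, :]) ** 2            # squared-distance cost on a 1-D grid
epsilon = 0.1

# Map over the batch dimension of the histograms only; C and epsilon are shared.
batched = vmap(sinkhorn_loss_fixed_iters, in_dims=(0, 0, None, None))(
    P_batch, Q_batch, C, epsilon
)
looped = torch.stack(
    [sinkhorn_loss_fixed_iters(p, q, C, epsilon) for p, q in zip(P_batch, Q_batch)]
)
print(batched)
```

The trade-off in this sketch is that every sample runs all `n_iters` iterations even if it has already converged, and a very small `reg` can underflow the kernel because nothing guards against division by zero.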

hmdolatabadi · Jun 01 '23 23:06