POT
`sinkhorn2` and its `functorch.vmap` compatibility
🚀 Feature
Making the `ot.sinkhorn2` function compatible with `functorch.vmap`.
Motivation
I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground truth. Until now I have been using a for-loop:
```python
import ot

loss = 0.0  # running sum of the per-sample Sinkhorn losses
for i in range(len(P_batch)):
    # Sinkhorn loss between the i-th prediction and its ground truth
    loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
```
But this is far too slow for my application. Reading through the functorch documentation, it seemed that I should be able to vectorize it with `vmap`:
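```python
# Map over the batch dimension of P and Q; C and epsilon are shared across samples
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P_batch, Q_batch, C, epsilon)
```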
But after wrapping the function in `vmap`, I get this weird error:
```
File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
502 v = b / KtransposeU
503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
506 or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
507 or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
508 # we have reached the machine precision
509 # come back to previous solution and quit loop
510 warnings.warn('Warning: numerical errors at iteration %d' % ii)
511 u = uprev
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
```
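For context, the same error can be reproduced without POT. The sketch below assumes PyTorch 2.x, where functorch's `vmap` is also exposed as `torch.func.vmap`; the functions `with_branch` and `without_branch` are hypothetical toys. A Python `if` on a tensor calls `bool()` on it, which `vmap` cannot batch, while the same logic written with `torch.where` is evaluated per sample:

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap  # standalone functorch package


def with_branch(x):
    # Python `if` on a tensor: needs a single True/False, which vmap can't provide
    if torch.any(x == 0):
        return x
    return 1.0 / x


def without_branch(x):
    # Same logic as a tensor op; the condition is evaluated separately per sample
    return torch.where(torch.any(x == 0), x, 1.0 / x)


batch = torch.tensor([[1.0, 2.0], [0.0, 4.0]])
result = vmap(without_branch)(batch)  # row 0 inverted, row 1 left unchanged

try:
    vmap(with_branch)(batch)
    failed = False
except RuntimeError:  # the "data-dependent control flow" error shown above
    failed = True
```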
Pitch
It looks like the data-dependent `if` statement (the numerical-error check at line 505 of `ot/bregman.py`) would need to be replaced with a vmap-compatible alternative. Any help is appreciated.
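One possible direction, as a minimal sketch under stated assumptions: a hypothetical `sinkhorn2_vmappable` helper (not part of POT and not POT's implementation) that runs plain Sinkhorn-Knopp updates for a fixed number of iterations and replaces the numerical-error `if` with `torch.where`, keeping the previous iterate whenever a step produces zeros, NaNs or infs. It returns the transport cost `<T, M>` of the entropic plan, which is intended to mirror what `ot.sinkhorn2` reports for a single target histogram; it should be checked against `ot.sinkhorn2` on real data before being relied on.

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap  # standalone functorch package


def sinkhorn2_vmappable(a, b, M, reg, num_iter=100):
    """Hypothetical helper (not part of POT): entropic OT cost <T, M>.

    a: (n,) source histogram, b: (m,) target histogram, M: (n, m) cost matrix.
    """
    K = torch.exp(-M / reg)
    u = torch.ones_like(a) / a.shape[0]
    v = torch.ones_like(b) / b.shape[0]
    # Fixed iteration count: stopping early on a convergence tolerance
    # would be another data-dependent branch.
    for _ in range(num_iter):
        u_prev, v_prev = u, v
        KtU = K.T @ u
        v = b / KtU
        u = a / (K @ v)
        # Branch-free version of the numerical-error check: if this step went
        # wrong, keep the previous iterate. Later steps then recompute the same
        # failing update, so the iterate stays frozen at the last good values.
        bad = (KtU == 0).any() | ~torch.isfinite(u).all() | ~torch.isfinite(v).all()
        u = torch.where(bad, u_prev, u)
        v = torch.where(bad, v_prev, v)
    T = u[:, None] * K * v[None, :]  # entropic transport plan
    return (T * M).sum()


# Map over the batch dimension of P and Q; C and epsilon are shared.
batched_loss = vmap(sinkhorn2_vmappable, in_dims=(0, 0, None, None))
```

With this sketch, `batched_loss(P_batch, Q_batch, C, epsilon)` assumes `P_batch` has shape `(B, n)`, `Q_batch` shape `(B, m)` and `C` shape `(n, m)`, i.e. one 1-D histogram per sample rather than the `.view(-1, 1)` columns used in the loop. Because `torch.where` still evaluates both branches, gradients through a discarded iterate can become NaN if the fallback ever triggers; a log-domain variant built on `torch.logsumexp` is the usual way to avoid the underflow of `K` for small `epsilon` in the first place.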