POT
batched ot.emd
🚀 Feature
To my understanding, the current implementation of `ot.emd` takes only two probability distributions and a cost matrix. Is there any implementation of `ot.emd` that takes batched input that I am missing?
Motivation
As `ot.emd` and `ot.emd2` can compute gradients, a batched implementation of EMD would greatly speed up training.
Is it possible to implement a version of `ot.emd` that takes batched input? If not, could you please explain why?
Also, if batched input is not possible, what is the best way to speed up the computation (other than using the regularized version)?
Thanks.
Actually, as detailed in the function documentation, `ot.emd` and `ot.emd2` can take GPU tensors, but the solver is CPU-bound, so there is a memory-copy overhead when the data is on the GPU. This overhead is relatively small on large problems but can be quite limiting when solving many small problems repeatedly.
There is an OpenMP implementation of the solver that can use multiple CPU cores, or, in practice, one can call the solver in parallel on multiple problems. But there is currently no way to write a batched exact OT solver, since no network flow solver is available on GPU yet. For a batched implementation, regularized OT is indeed the best option.
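To illustrate why regularized OT batches well, here is a minimal entropic (Sinkhorn) sketch in NumPy that updates all problems of a batch with the same vectorized operations. This is a plain textbook Sinkhorn iteration written for illustration, not POT's `ot.sinkhorn`; `batched_sinkhorn`, its stopping rule, and the parameter values are assumptions. The same `einsum` and broadcasting steps carry over to torch tensors on a GPU, where autograd can differentiate through the iterations. For very small `reg`, a log-domain variant would be needed to avoid underflow in `exp(-M / reg)`.

```python
import numpy as np


def batched_sinkhorn(a, b, M, reg, n_iter=1000, tol=1e-9):
    """Entropic OT plans for a batch of problems (hypothetical sketch).

    a: (B, ns), b: (B, nt), M: (B, ns, nt); reg > 0 is the entropic weight.
    Returns plans of shape (B, ns, nt).
    """
    K = np.exp(-M / reg)  # one Gibbs kernel per problem
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # K^T u for every problem at once -> shape (B, nt)
        v = b / np.einsum("bij,bi->bj", K, u)
        # K v for every problem at once -> shape (B, ns)
        u = a / np.einsum("bij,bj->bi", K, v)
        # after the u update the row marginals are exact; stop once the
        # column marginals match b for the whole batch
        P = u[:, :, None] * K * v[:, None, :]
        if np.max(np.abs(P.sum(axis=1) - b)) < tol:
            break
    return u[:, :, None] * K * v[:, None, :]


# Usage: 8 small problems solved together.
rng = np.random.default_rng(0)
B, ns, nt = 8, 5, 6
a = np.full((B, ns), 1.0 / ns)
b = np.full((B, nt), 1.0 / nt)
M = rng.random((B, ns, nt))
P = batched_sinkhorn(a, b, M, reg=0.2)
costs = np.sum(P * M, axis=(1, 2))  # one regularized-plan cost per problem
print(costs)
```

Every iteration is a couple of batched contractions over all `B` problems, which is the kind of workload GPUs handle well; an exact network simplex, by contrast, follows a data-dependent sequence of pivots per problem.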
Closing this one as a duplicate of #532, which is more detailed.