
batched ot.emd

siddharthKatageri opened this issue on Aug 02 '22 · 1 comment

🚀 Feature

To my understanding, the current implementation of ot.emd takes only two probability distributions and a cost matrix. Is there a batched implementation of ot.emd that I am missing?

Motivation

Since ot.emd and ot.emd2 support gradient computation, a batched implementation of emd would help a lot in speeding up training.

Is it possible to implement ot.emd so that it takes batched input? If not, can you please explain why? Also, if batched input is not possible, what is the best way to speed up the computation (other than using the regularized version)?

Thanks.

siddharthKatageri · Aug 02 '22 18:08

Actually, as detailed in the function documentation, ot.emd and ot.emd2 can take GPU tensors, but the solver is CPU-bound, so there is a memory-copy overhead when running on GPU. The overhead is relatively small on large problems but can be quite limiting when many small problems are solved repeatedly.
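For concreteness, here is a minimal sketch of that GPU-tensor path (sizes and variable names are illustrative; it assumes a recent POT version with the torch backend and a CUDA device). The exact solve still happens on CPU, so the tensors are copied back and forth around the call:

```python
# Sketch: ot.emd2 accepts framework tensors via POT's backend, but the
# exact network-simplex solve runs on CPU, so GPU tensors incur a
# device-to-host copy around each call.
import torch
import ot

n, m = 64, 64
x = torch.randn(n, 2, device="cuda", requires_grad=True)  # source points
y = torch.randn(m, 2, device="cuda")                      # target points

a = torch.full((n,), 1.0 / n, device="cuda")  # uniform source weights
b = torch.full((m,), 1.0 / m, device="cuda")  # uniform target weights
M = torch.cdist(x, y) ** 2                    # squared Euclidean cost matrix

loss = ot.emd2(a, b, M)  # exact OT cost, differentiable w.r.t. M (hence x)
loss.backward()          # gradients flow back to x despite the CPU solve
```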

There is an OpenMP implementation of the solver that can benefit from multiple CPU cores, and in practice one can also call the solver in parallel on multiple problems. But there is currently no way to do batched exact OT, since no network-flow solver is available on GPU yet. For a batched implementation, regularized OT is indeed the best option.
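As a rough illustration of the two CPU-side workarounds mentioned above, assuming a list of independent, same-sized problems (the numThreads argument is available in recent POT versions; joblib is an external dependency, not part of POT):

```python
# Sketch: speed up many small exact OT solves either by giving each
# solve more CPU threads (OpenMP) or by dispatching independent
# problems to separate workers.
import numpy as np
import ot
from joblib import Parallel, delayed

rng = np.random.default_rng(0)
problems = [(np.ones(50) / 50, np.ones(50) / 50, rng.random((50, 50)))
            for _ in range(100)]

# Option 1: multi-threaded network simplex on each problem.
costs_mt = [ot.emd2(a, b, M, numThreads=4) for a, b, M in problems]

# Option 2: solve independent problems in parallel worker processes.
costs_par = Parallel(n_jobs=4)(
    delayed(ot.emd2)(a, b, M) for a, b, M in problems)
```

And since regularized OT is the suggested route for true batching, here is a generic log-domain Sinkhorn written directly on a batch of cost matrices (a hand-rolled sketch, not a POT API):

```python
import torch

def batched_sinkhorn(a, b, M, reg=0.05, n_iter=200):
    """Entropic OT costs for a batch of problems.
    a: (B, n), b: (B, m), M: (B, n, m); returns (B,) transport costs."""
    log_a, log_b = a.log(), b.log()
    f = torch.zeros_like(a)  # dual potentials, updated in the log domain
    g = torch.zeros_like(b)
    for _ in range(n_iter):
        f = -reg * torch.logsumexp((g[:, None, :] - M) / reg + log_b[:, None, :], dim=2)
        g = -reg * torch.logsumexp((f[:, :, None] - M) / reg + log_a[:, :, None], dim=1)
    # Transport plan P_ij = a_i * b_j * exp((f_i + g_j - M_ij) / reg)
    P = torch.exp((f[:, :, None] + g[:, None, :] - M) / reg
                  + log_a[:, :, None] + log_b[:, None, :])
    return (P * M).sum(dim=(1, 2))
```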

rflamary · Aug 08 '22 06:08

Closing this one as a duplicate of #532, which is more detailed.

rflamary · Mar 01 '24 07:03