
Optimal transport with torch + GPU?

Open · spinjo opened this issue 1 year ago · 1 comment

In the current implementation, tensors are moved to the CPU and converted to numpy arrays before calling the optimal transport solver, see e.g. https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py#L88.

Since version 0.8, the POT package supports backends beyond numpy, including torch with GPU acceleration, see https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends-on-cpu-gpu. This can speed up the OT solver, especially for large batch sizes, and enables new features such as differentiating through the solver. Is there a reason why torchcfm uses the numpy + CPU policy?

I am successfully using the torch + GPU backend of POT and would be happy to file a PR if there is interest in including this in torchcfm.

spinjo · Sep 28 '24 11:09