geomloss
Batch support in SamplesLoss
First of all thanks for the great library!
I just tried to run SamplesLoss on batches of data, and it did not work.
I have two tensors x and y of the same shape [batch_dim, n_points, feature_dim]. I want to compute the Sinkhorn divergence between the point clouds x[0] and y[0], x[1] and y[1], and so on, in a single batched call (to avoid a slow Python loop), so that the result is a tensor of shape [batch_dim].
However, calling SamplesLoss() on these batched tensors raises a shape error.
A minimal example that reproduces the error is available in this Colab notebook: https://colab.research.google.com/drive/1NqagWVIv-a8YN258NcFEBXbRBFAVuuiR?usp=sharing
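The batched computation described above can be sketched in plain NumPy to make the expected shapes concrete. This is a minimal illustration, not GeomLoss's implementation: it assumes uniform weights, the squared-Euclidean cost |x - y|^2 / 2, and a fixed number of log-domain Sinkhorn iterations with no convergence check. The function names sinkhorn_cost and sinkhorn_divergence and the parameters eps and n_iters are hypothetical. Every operation broadcasts over the leading batch axis, so one call returns one value per pair (x[i], y[i]).

```python
import numpy as np


def _logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def sinkhorn_cost(x, y, eps=0.1, n_iters=200):
    """Hypothetical helper: entropic OT cost for each batch element.

    x: (B, N, D), y: (B, M, D), uniform weights on both clouds.
    Returns an array of shape (B,).
    """
    B, N, _ = x.shape
    M = y.shape[1]
    # Pairwise cost |x_i - y_j|^2 / 2 for every batch element: (B, N, M).
    C = 0.5 * np.sum((x[:, :, None, :] - y[:, None, :, :]) ** 2, axis=-1)
    log_a = np.full((B, N), -np.log(N))  # log of uniform weights on x
    log_b = np.full((B, M), -np.log(M))  # log of uniform weights on y
    f = np.zeros((B, N))  # dual potential on x
    g = np.zeros((B, M))  # dual potential on y
    for _ in range(n_iters):
        # Log-domain Sinkhorn updates (soft-min over the other cloud).
        f = -eps * _logsumexp(log_b[:, None, :] + (g[:, None, :] - C) / eps, axis=2)
        g = -eps * _logsumexp(log_a[:, :, None] + (f[:, :, None] - C) / eps, axis=1)
    # Dual value <a, f> + <b, g>; with uniform weights these are means.
    return np.mean(f, axis=1) + np.mean(g, axis=1)


def sinkhorn_divergence(x, y, eps=0.1, n_iters=200):
    """Hypothetical debiased Sinkhorn divergence, batched: returns shape (B,)."""
    xy = sinkhorn_cost(x, y, eps, n_iters)
    xx = sinkhorn_cost(x, x, eps, n_iters)
    yy = sinkhorn_cost(y, y, eps, n_iters)
    return xy - 0.5 * xx - 0.5 * yy


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 20, 3))  # [batch_dim, n_points, feature_dim]
y = x + 3.0                      # shifted copies of each point cloud
print(sinkhorn_divergence(x, y).shape)  # one value per batch element
```

Because nothing mixes values across the first axis, the batched result matches a loop over x[i:i+1], y[i:i+1] up to floating-point rounding, while avoiding the Python-level loop.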
You have to install the GitHub version of GeomLoss with the following command:
pip install git+https://github.com/jeanfeydy/geomloss
This problem has been fixed in a commit that is more recent than the last release.