
Batch support for SamplesLoss

Open · mi92 opened this issue 4 years ago · 1 comment

First of all thanks for the great library!

I tried running SamplesLoss on batched data, but it failed.

I have two tensors x and y of the same shape [batch_dim, n_points, feature_dim]. I want to compute the Sinkhorn divergence between the point clouds x[0] and y[0], x[1] and y[1], and so on, in a single batched call (avoiding a slow Python loop), so that the result is a tensor of shape [batch_dim].
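For reference, the batched computation described above can be sketched from scratch. This is a NumPy illustration, not GeomLoss's implementation: it assumes uniform weights on every point, the squared-Euclidean cost |x - y|^2 / 2, a fixed blur parameter eps and a fixed number of plain log-domain Sinkhorn iterations (GeomLoss adds refinements such as eps-scaling). The helper names entropic_ot and sinkhorn_divergence are hypothetical.

```python
import numpy as np


def _logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(z - m), axis=axis, keepdims=True)), axis=axis)


def entropic_ot(x, y, eps=0.1, n_iter=100):
    """Hypothetical helper: batched entropic OT between uniform point clouds.

    x: (B, N, D), y: (B, M, D) -> returns an array of shape (B,).
    Assumes the cost |x - y|^2 / 2 and plain log-domain Sinkhorn updates.
    """
    B, N, _ = x.shape
    M = y.shape[1]
    # Pairwise costs for all batch elements at once: shape (B, N, M).
    C = 0.5 * np.sum((x[:, :, None, :] - y[:, None, :, :]) ** 2, axis=-1)
    log_a = np.full((B, N), -np.log(N))  # uniform weights on x
    log_b = np.full((B, M), -np.log(M))  # uniform weights on y
    f = np.zeros((B, N))
    g = np.zeros((B, M))
    for _ in range(n_iter):
        # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
        f = -eps * _logsumexp(log_b[:, None, :] + (g[:, None, :] - C) / eps, axis=2)
        # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
        g = -eps * _logsumexp(log_a[:, :, None] + (f[:, :, None] - C) / eps, axis=1)
    # Dual value <a, f> + <b, g>; with uniform weights these are plain means.
    return f.mean(axis=1) + g.mean(axis=1)


def sinkhorn_divergence(x, y, eps=0.1, n_iter=100):
    # Debiased divergence, one value per batch element:
    # S(x, y) = OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2
    return (entropic_ot(x, y, eps, n_iter)
            - 0.5 * entropic_ot(x, x, eps, n_iter)
            - 0.5 * entropic_ot(y, y, eps, n_iter))


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 50, 3))        # [batch_dim, n_points, feature_dim]
y = rng.normal(size=(4, 50, 3)) + 1.0  # shifted point clouds
loss = sinkhorn_divergence(x, y)
print(loss.shape)  # (4,)
```

Broadcasting over the leading axis replaces the Python loop over batch elements; the price is materializing a (B, N, M) cost array in memory.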

However, calling SamplesLoss() on these tensors raises a shape error.

A minimal reproducible example is in this Colab notebook: https://colab.research.google.com/drive/1NqagWVIv-a8YN258NcFEBXbRBFAVuuiR?usp=sharing

mi92 · Oct 01 '20 08:10

You need to install the GitHub version of GeomLoss with the following command: pip install git+https://github.com/jeanfeydy/geomloss

This problem was fixed in a commit made after the latest release.

NightWinkle · Oct 01 '20 11:10