
[Question] The most efficient way to calculate OTLoss between the rows of a 3D tensor?

Open · aGIToz opened this issue 3 years ago · 0 comments

Hello,

Thank you for this wonderful project!

I am wondering about something; I hope someone can help.

Consider a small toy example:

import torch
a = torch.randn(10, 8, 5)

I am looking for the most efficient way to compute a distance matrix D of shape (10, 5) such that:

D[i,j] = OTLoss(a[i,:,0], a[i,:,j])

Evidently, D[i,0] = 0 for all i. Here i runs from 0 to 9 and j runs from 0 to 4, so D has shape (10, 5).
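To pin down the definition, here is a minimal loop-based reference sketch. It assumes (this is not stated in the question) that OTLoss is the exact squared 2-Wasserstein cost between the uniform empirical distributions formed by the 8 values of each slice. For two equal-size 1D samples with uniform weights, the optimal plan simply matches sorted values, so no solver is needed. The names `ot_loss_1d` and `distance_matrix_loop` are hypothetical, not part of geomloss.

```python
import torch


def ot_loss_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Assumed stand-in for OTLoss: squared 2-Wasserstein distance between
    two equal-size 1D samples with uniform weights.

    In 1D the optimal transport plan pairs the k-th smallest value of x
    with the k-th smallest value of y, so sorting yields the exact cost.
    """
    xs, _ = torch.sort(x)
    ys, _ = torch.sort(y)
    return ((xs - ys) ** 2).mean()


def distance_matrix_loop(a: torch.Tensor) -> torch.Tensor:
    """D[i, j] = ot_loss_1d(a[i, :, 0], a[i, :, j]) with explicit loops."""
    n_rows, _, n_cols = a.shape
    D = torch.empty(n_rows, n_cols, dtype=a.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            D[i, j] = ot_loss_1d(a[i, :, 0], a[i, :, j])
    return D


torch.manual_seed(0)
a = torch.randn(10, 8, 5)
D = distance_matrix_loop(a)
print(D.shape)                    # torch.Size([10, 5])
print(bool((D[:, 0] == 0).all()))  # True: each slice has zero cost to itself
```

This is only a baseline for checking faster versions; with 100k rows the double Python loop would be far too slow.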

Now, assuming that a is a very large 3D tensor (say, of shape (100k, 100, 100)), what is the most efficient way to compute D? I am hoping there is a solution that avoids loops.
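Under the same assumption as a 1D exact-OT reading of the question (OTLoss taken to be the squared 2-Wasserstein cost between uniform 1D samples, which is not geomloss's Sinkhorn loss), the whole computation can be done without Python loops: sort every slice along the sample axis once, then broadcast the sorted reference column j = 0 against all columns. The function names below are hypothetical. Note that a float32 tensor of shape (100k, 100, 100) already takes about 4 GB, and `torch.sort` also allocates an int64 index tensor of the same shape, so the sketch includes an optional variant that processes rows in chunks (the only loop is over chunks, not over (i, j) pairs).

```python
import torch


def distance_matrix_vectorized(a: torch.Tensor) -> torch.Tensor:
    """Loop-free D[i, j] = W2^2(a[i, :, 0], a[i, :, j]) for 1D uniform samples.

    Assumes OTLoss is the exact 1D squared 2-Wasserstein cost, which in 1D
    reduces to comparing sorted values (the monotone matching is optimal).
    """
    # Sort every slice a[i, :, j] along the sample axis (dim=1) in one call.
    a_sorted, _ = torch.sort(a, dim=1)          # (B, N, J)
    ref = a_sorted[:, :, :1]                    # (B, N, 1): sorted a[i, :, 0]
    # Broadcast the reference column against all J columns, average over N.
    return ((a_sorted - ref) ** 2).mean(dim=1)  # (B, J)


def distance_matrix_chunked(a: torch.Tensor, chunk_size: int = 10_000) -> torch.Tensor:
    """Same result, computed over row chunks to bound peak memory."""
    blocks = torch.split(a, chunk_size, dim=0)
    return torch.cat([distance_matrix_vectorized(b) for b in blocks], dim=0)


torch.manual_seed(0)
a = torch.randn(10, 8, 5)
D = distance_matrix_vectorized(a)
print(D.shape)  # torch.Size([10, 5])
```

The sorting trick is exact only for 1D samples of equal size with uniform weights. If an entropic (Sinkhorn) loss from geomloss is what is actually wanted, my understanding is that `SamplesLoss` also accepts batched point clouds of shape (B, N, D) with its tensorized backend, so one could reshape the slices into B = 10 · 5 pairs of (8, 1) clouds and get a (B,) result to view as (10, 5); that route is untested here.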

aGIToz, Apr 29 '21 15:04