geomloss
[Question] The most efficient way to calculate OTLoss between the rows of a 3D-tensor?
Hello,
Thank you for this wonderful project!
I am wondering about something and hope someone can help.
Consider a toy example with a small size:

`a = torch.randn(10, 8, 5)`

I am looking for the most efficient way to create a distance matrix `D` of size `(10, 5)` such that:

`D[i, j] = OTLoss(a[i, :, 0], a[i, :, j])`
Evidently, `D[i, 0] = 0` for all `i`. Here `i` runs from 0 to 9 and `j` runs from 0 to 4, so `D` has shape `(10, 5)`.
Now, assuming that `a` is a very large 3D tensor (say, of shape `(100k, 100, 100)`), what is the most efficient way to compute `D`?
I am hoping there is a solution that does not use loops.
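Under the same 1D assumption (uniformly weighted scalar samples, exact sorted-matching cost instead of a Sinkhorn loss), the double loop can be removed entirely. `torch.sort(..., dim=1)` sorts every column of every row in one call, and broadcasting the sorted reference column `j = 0` against all columns yields `D` directly. The function name and the `chunk_size` parameter are hypothetical. This trick is exact only for 1D samples of equal size; for higher-dimensional points or an entropic loss, geomloss's `SamplesLoss` accepts, as far as I know, batched inputs of shape `(B, N, D)`, so reshaping to `(10 * 5, 8, 1)` would be another loop-free route.

```python
import torch
from typing import Optional


def distance_matrix_sorted(
    a: torch.Tensor, p: int = 2, chunk_size: Optional[int] = None
) -> torch.Tensor:
    """D[i, j] = exact 1D W_p^p between a[i, :, 0] and a[i, :, j], with no Python
    loops over i or j. Assumes each column holds n uniformly weighted scalars."""

    def _block(block: torch.Tensor) -> torch.Tensor:
        # Sort every column of every row at once, along the sample axis (dim=1).
        s = torch.sort(block, dim=1).values   # (B, n, m)
        # Broadcast the sorted reference column (j = 0) against all m columns.
        diff = s - s[:, :, :1]                # (B, n, m)
        # Average the matched-pair costs over the n samples.
        return diff.abs().pow(p).mean(dim=1)  # (B, m)

    if chunk_size is None:
        return _block(a)
    # Optional memory knob: split the first axis into chunks; each chunk is
    # still processed fully vectorized.
    return torch.cat([_block(c) for c in torch.split(a, chunk_size, dim=0)], dim=0)


torch.manual_seed(0)
a = torch.randn(10, 8, 5)
D = distance_matrix_sorted(a)
print(D.shape)  # torch.Size([10, 5])
```

For `a` of shape `(100k, 100, 100)` in float32, the input alone is about 4 GB (10^9 values × 4 bytes), and `torch.sort` also materializes int64 indices (about 8 GB more). Passing something like `chunk_size=10_000` bounds peak memory while each chunk stays vectorized; the cost is dominated by one sort of length `n` per column, i.e. O(B·m·n log n).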