differentiable_sorting
Sorting 3D or 4D vectors
I love this library and got it working well with PyTorch:
```python
import torch
# bitonic_matrices and vector_sort come from differentiable_sorting's PyTorch module

# 2D tensor
input_tensor = torch.tensor([
    [1, 5],
    [30, 30],
    [6, 9],
    [80, -2],
]).float()
mask = torch.tensor([1, 0]).float()
vector_sort(
    bitonic_matrices(4),
    input_tensor,
    lambda x: x @ mask,  # sort by the first column (the mask picks index 0)
    alpha=1.0
)
```

```
tensor([[80.0000, -2.0000],
        [30.0000, 30.0000],
        [ 5.9665,  8.9732],
        [ 1.0335,  5.0268]])
```
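The blurred last two rows can be reproduced by hand. Judging from the weighting shown in the traceback further down, `a_weight = exp(a * alpha) / (exp(a * alpha) + exp(b * alpha))`, each compare-and-swap presumably blends the two rows by that weight instead of swapping them outright. The pure-Python sketch below assumes exactly that; `soft_compare_swap` is a hypothetical helper, not part of the library. Applied to `[6, 9]` and `[1, 5]` (keys 6 and 1) with `alpha=1.0`, it gives values that agree with the last two rows of the output above to four decimals.

```python
import math


def stable_sigmoid(z):
    """1 / (1 + exp(-z)), computed without overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_compare_swap(row_a, row_b, key, alpha=1.0):
    """Hypothetical soft swap: blend two rows into a (max, min) pair by key.

    For keys a and b, exp(alpha*a) / (exp(alpha*a) + exp(alpha*b)) equals
    sigmoid(alpha * (a - b)), which is the form used here for stability.
    """
    a, b = key(row_a), key(row_b)
    w = stable_sigmoid(alpha * (a - b))
    hi = [w * x + (1 - w) * y for x, y in zip(row_a, row_b)]
    lo = [(1 - w) * x + w * y for x, y in zip(row_a, row_b)]
    return hi, lo


def first_column(row):
    # Same role as `lambda x: x @ mask` with mask [1, 0]
    return row[0]


hi, lo = soft_compare_swap([6.0, 9.0], [1.0, 5.0], first_column, alpha=1.0)
print([round(v, 4) for v in hi])  # [5.9665, 8.9732]
print([round(v, 4) for v in lo])  # [1.0335, 5.0268]
```

Raising `alpha` pushes the weight toward 0 or 1, so the blend approaches a hard swap; here a key gap of 5 leaks only about 0.7% of each row into the other.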
But I am now trying to extend this to higher dimensions:
```python
# 3D tensor
input_tensor = torch.tensor([
    [[1, 5], [30, 30], [6, 9], [80, -2]],
    [[2, 6], [31, 31], [7, 10], [81, -1]],
]).float()
# desired result: each 4x2 slice sorted by its first column, descending
target_tensor = torch.tensor([
    [[80, -2], [30, 30], [6, 9], [1, 5]],
    [[81, -1], [31, 31], [7, 10], [2, 6]],
])
mask = torch.tensor([1, 0]).float()
vector_sort(
    bitonic_matrices(8),
    input_tensor,
    lambda x: x @ mask,
    alpha=1.0
)
```
But I receive the error:
```
~\anaconda3\lib\site-packages\differentiable_sorting\torch\differentiable_sorting_torch.py in vector_sort(matrices, X, key, alpha)
     85     x = key(X)
     86     # compute weighting on the scalar function
---> 87     a, b = l @ x, r @ x
     88     a_weight = torch.exp(a * alpha) / (torch.exp(a * alpha) + torch.exp(b * alpha))
     89     b_weight = 1 - a_weight
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x8 and 2x4)
```
How would I go about extending the sorting to work for 3D or nD tensors?
This is probably too late to be useful (sorry!), but you'd need to either (a) unravel your data, sort it, and then reshape it back into the nD tensor, if all you want is to sort the elements independently; or (b) define a (differentiable) comparator function, if you want to use the tensor structure in the sorting, and then call comparison_sort(matrices, my_comparator) (e.g. you could sort the matrices by the sum of their rows this way).
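Options (a) and (b) can be sketched without the library. The pure-Python code below is an illustration under assumptions, not the library's implementation. It uses an odd-even transposition network, a simpler O(n²) stand-in for the bitonic network that bitonic_matrices builds, with soft compare-and-swaps weighted like the formula in the traceback. It also defines a hypothetical `soft_sort(items, weight)` whose `weight(a, b)` plays the role of a differentiable comparator. Whether comparison_sort expects a comparator of exactly this form is not confirmed here. For (a) the sketch flattens a 2x2x2 nested list, sorts the eight scalars and reshapes them; for (b) it orders whole 2x2 matrices by the sum of their rows.

```python
import math


def stable_sigmoid(z):
    """1 / (1 + exp(-z)) without overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def blend(w, a, b):
    """w * a + (1 - w) * b, applied elementwise through nested lists."""
    if isinstance(a, (int, float)):
        return w * a + (1 - w) * b
    return [blend(w, x, y) for x, y in zip(a, b)]


def soft_sort(items, weight):
    """Descending soft sort using an odd-even transposition network.

    weight(a, b) is a hypothetical differentiable comparator returning a
    value in [0, 1] that is close to 1 when a should come before b.
    """
    items = list(items)
    n = len(items)
    for rnd in range(n):  # n rounds suffice for odd-even transposition
        for i in range(rnd % 2, n - 1, 2):
            a, b = items[i], items[i + 1]
            w = weight(a, b)
            items[i], items[i + 1] = blend(w, a, b), blend(w, b, a)
    return items


alpha = 20.0

# (a) flatten a 2x2x2 nested list, sort every scalar, reshape back to 2x2x2
tensor = [[[3.0, 1.0], [4.0, 8.0]], [[5.0, 9.0], [2.0, 6.0]]]
flat = [v for plane in tensor for row in plane for v in row]
s = soft_sort(flat, lambda a, b: stable_sigmoid(alpha * (a - b)))
reshaped = [[s[i:i + 2] for i in range(p, p + 4, 2)] for p in range(0, 8, 4)]
print([[[round(v, 3) for v in row] for row in plane] for plane in reshaped])
# [[[9.0, 8.0], [6.0, 5.0]], [[4.0, 3.0], [2.0, 1.0]]]


# (b) order whole 2x2 matrices by the sum of their rows
def total(m):
    return sum(sum(row) for row in m)


mats = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]], [[0.0, 0.0], [0.0, 0.0]]]
ranked = soft_sort(mats, lambda A, B: stable_sigmoid(alpha * (total(A) - total(B))))
print([[round(v, 3) for row in m for v in row] for m in ranked])
# [[2.0, 2.0, 2.0, 2.0], [1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
```

Note that (a) mixes every element together, so it is only the right choice when the tensor's structure should be ignored.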
Alternatively, (c) if your sorting doesn't need a fully custom comparator and you can instead map from some space (e.g. row vectors) to scalars (as in the row-sum example), you could use vector_sort() with a key function: each input is mapped through the key, and the sort is done on the resulting scalars.
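For the 3D tensor in the question, approach (c) amounts to sorting each 4x2 slice on its own. The error arises because `lambda x: x @ mask` turns the 2x4x2 input into a 2x4 matrix of keys, while the 4x8 matrix from `bitonic_matrices(8)` expects a vector of 8 keys. With the library, the per-slice version would presumably call `vector_sort(bitonic_matrices(4), X[i], key, alpha)` for each `i` and combine the results with `torch.stack`, but that call is untested here. The pure-Python sketch below shows the same loop-over-the-leading-axis idea. Its `soft_vector_sort` and `batched_soft_vector_sort` are hypothetical helpers built from an odd-even transposition network, not the library's bitonic network.

```python
import math


def stable_sigmoid(z):
    """1 / (1 + exp(-z)) without overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_vector_sort(rows, key, alpha=1.0):
    """Hypothetical stand-in for vector_sort on one 2D slice (descending).

    Runs an odd-even transposition network of soft compare-and-swaps,
    blending each adjacent pair of rows by sigmoid(alpha * key gap).
    """
    rows = [list(r) for r in rows]
    n = len(rows)
    for rnd in range(n):
        for i in range(rnd % 2, n - 1, 2):
            a, b = rows[i], rows[i + 1]
            w = stable_sigmoid(alpha * (key(a) - key(b)))
            rows[i] = [w * x + (1 - w) * y for x, y in zip(a, b)]
            rows[i + 1] = [(1 - w) * x + w * y for x, y in zip(a, b)]
    return rows


def batched_soft_vector_sort(batch, key, alpha=1.0):
    """Sort each slice along the leading axis independently."""
    return [soft_vector_sort(slice_, key, alpha) for slice_ in batch]


def first_column(row):
    # Plays the role of `x @ mask` with mask [1, 0]
    return row[0]


input_tensor = [
    [[1, 5], [30, 30], [6, 9], [80, -2]],
    [[2, 6], [31, 31], [7, 10], [81, -1]],
]
result = batched_soft_vector_sort(input_tensor, first_column, alpha=10.0)
for slice_ in result:
    print([[round(v, 4) for v in row] for row in slice_])
# [[80.0, -2.0], [30.0, 30.0], [6.0, 9.0], [1.0, 5.0]]
# [[81.0, -1.0], [31.0, 31.0], [7.0, 10.0], [2.0, 6.0]]
```

Because each slice goes through its own network, rows never mix across the leading axis, which is what target_tensor describes. Flattening the batch into 8 rows, as `bitonic_matrices(8)` implies, would instead sort all 8 rows together.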