fast-soft-sort

Understanding soft-sorting

Open vamp-ire-tap opened this issue 3 years ago • 3 comments

Hello,

What is the difference between the provided implementation (soft-sort) and torch.sort? Sorry for the stupid question, but I cannot see how the supposedly non-differentiable torch.sort really differs from the soft sort: in the example below, backpropagation works through both.

import torch.nn.functional as F
import torch
import torch.nn as nn
import pytorch_ops
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20,20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        out = pytorch_ops.soft_sort(self.out(out)).float()
        return out

class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20,20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        out = torch.sort(self.out(out))[0]
        return out

targets = torch.rand(32,20)
inputs = torch.rand(32,20)
net = Net()
net2 = Net2()
# loss function (was undefined in the original snippet; MSE assumed here)
criterion = nn.MSELoss()

#try with soft-sort
loss = criterion(targets, net(inputs))
loss.backward()

#try with torch.sort
loss = criterion(targets, net2(inputs))
loss.backward()

#both work!

vamp-ire-tap avatar Oct 20 '21 14:10 vamp-ire-tap

Hello,

The difference is that the soft sort may lead to easier optimization problems; see "convexification" in https://arxiv.org/pdf/2002.08871.pdf. Intuitively, while the Jacobian of the normal sort operator is a permutation matrix, the Jacobian of the soft sort operator can be denser (depending on the projection function).

Note that we also provide a smooth approximation to the rank function. The exact rank function is not differentiable in any useful sense, as its output is integer-valued.

josipd avatar Oct 21 '21 14:10 josipd
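
To make the remark about the rank function concrete, here is a minimal sketch using only PyTorch. The helper `pairwise_soft_rank` and its temperature `tau` are hypothetical names introduced for illustration; it is a simple pairwise-sigmoid relaxation, not the projection-based soft rank that fast-soft-sort implements.

```python
import torch

x = torch.tensor([0.3, -1.2, 2.0, 0.7], requires_grad=True)

# Exact (hard) ranks, 1-based: an integer tensor with no gradient attached.
hard_rank = torch.argsort(torch.argsort(x)) + 1


def pairwise_soft_rank(v, tau=0.1):
    """Hypothetical smooth rank: counts, softly, how many entries are below v[i].

    NOTE: illustrative assumption only, not the fast-soft-sort algorithm.
    """
    diff = (v.unsqueeze(1) - v.unsqueeze(0)) / tau  # diff[i, j] = (v[i] - v[j]) / tau
    # The diagonal contributes sigmoid(0) = 0.5, so adding 0.5 gives 1-based ranks.
    return torch.sigmoid(diff).sum(dim=1) + 0.5


soft_rank = pairwise_soft_rank(x)

# Backpropagate through a single soft rank; the gradient w.r.t. x is non-zero.
soft_rank[0].backward()
grad_of_first_rank = x.grad.clone()
```

The hard ranks come back as `int64` and carry no gradient, while the relaxed ranks are floating-point values that autograd can differentiate. Smaller `tau` brings the relaxation closer to the hard ranks at the cost of smaller gradients.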

Thank you for your answer, but torch.sort() is differentiable, right? Not easy to optimize, but differentiable nevertheless?

vamp-ire-tap avatar Oct 22 '21 12:10 vamp-ire-tap

Yes, it is differentiable almost everywhere, with a Jacobian matrix that is a permutation matrix.

josipd avatar Oct 26 '21 15:10 josipd
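
The statement that the Jacobian of torch.sort is a permutation matrix can be checked directly. The sketch below uses `torch.autograd.functional.jacobian` on a small vector with distinct entries; the input values and the helper name `sorted_values` are arbitrary choices for illustration.

```python
import torch
from torch.autograd.functional import jacobian

x = torch.tensor([0.3, -1.2, 2.0, 0.7])


def sorted_values(v):
    # Only the sorted values are differentiable; the indices are integers.
    return torch.sort(v).values


# J[i, j] = d sorted[i] / d x[j]
J = jacobian(sorted_values, x)

# Build the permutation matrix implied by the sort indices ([1, 0, 3, 2] here).
perm = torch.sort(x).indices
expected = torch.zeros(len(x), len(x))
expected[torch.arange(len(x)), perm] = 1.0

is_permutation_matrix = torch.equal(J, expected)
```

Each output entry depends on exactly one input entry, so every row and column of `J` contains a single 1. Gradients therefore flow, but each sorted value only "sees" the one input it came from, which is the sparsity the soft sort relaxes.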