fast-soft-sort
Understanding soft-sorting
Hello,
What would be the difference between the provided implementation (soft_sort) and the torch.sort version? Sorry for the naive question, but I am not able to see how the non-differentiable torch.sort is really different from the soft sort.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fast_soft_sort import pytorch_ops

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        # Differentiable soft sort along the last dimension.
        out = pytorch_ops.soft_sort(self.out(out)).float()
        return out

class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(20, 20)
        self.out = nn.Linear(20, 20)

    def forward(self, x):
        out = F.relu(self.layer(x))
        # Hard sort; [0] keeps the sorted values and discards the indices.
        out = torch.sort(self.out(out))[0]
        return out

criterion = nn.MSELoss()  # was undefined in the original snippet
targets = torch.rand(32, 20)
inputs = torch.rand(32, 20)
net = Net()
net2 = Net2()

# try with soft_sort
loss = criterion(targets, net(inputs))
loss.backward()

# try with torch.sort
loss = criterion(targets, net2(inputs))
loss.backward()
# both work!
Hello,
The difference is that the soft sort can lead to easier optimization problems; see "convexification" in https://arxiv.org/pdf/2002.08871.pdf. Intuitively, while the Jacobian of the regular sort operator is a permutation matrix, the Jacobian of the soft sort operator can be denser (depending on the projection function).
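To make that concrete, here is a minimal sketch comparing the two Jacobians on a single row. It assumes fast_soft_sort is installed with the pytorch_ops module used above; the regularization_strength argument and the 2D float64 CPU input follow that module's interface as I understand it, so treat the details as assumptions rather than a definitive recipe.

import torch
from fast_soft_sort import pytorch_ops

# A single row; the fast_soft_sort PyTorch ops expect 2D CPU tensors.
x = torch.tensor([[2.0, 1.0, 3.0]], dtype=torch.float64)

# Hard sort: the 3x3 Jacobian is a permutation matrix (a single 1 per row).
hard_jac = torch.autograd.functional.jacobian(lambda v: torch.sort(v)[0], x)
print(hard_jac.reshape(3, 3))

# Soft sort: the Jacobian can have dense rows, and it tends to get denser
# as the regularization strength grows.
soft_jac = torch.autograd.functional.jacobian(
    lambda v: pytorch_ops.soft_sort(v, regularization_strength=1.0), x)
print(soft_jac.reshape(3, 3))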
Note that we also provide a smooth approximation to the rank function; unlike sort, the hard rank function is not usefully differentiable, as its output is integer-valued (so its derivatives are zero wherever they exist).
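As a small illustration of why the soft version matters for training (again a sketch assuming the same pytorch_ops module and its regularization_strength argument), soft_rank produces real-valued "ranks" through which nonzero gradients flow, whereas hard ranks would give zero gradient everywhere:

import torch
from fast_soft_sort import pytorch_ops

x = torch.tensor([[3.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True)

# Small regularization -> close to the hard ranks [3, 1, 2].
print(pytorch_ops.soft_rank(x, regularization_strength=0.1))

# Larger regularization -> smoother, real-valued ranks with useful gradients.
ranks = pytorch_ops.soft_rank(x, regularization_strength=2.0)
loss = ((ranks - torch.tensor([[2.0, 1.0, 3.0]])) ** 2).sum()
loss.backward()
print(x.grad)  # nonzero, unlike the gradient of integer-valued hard ranks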
Thank you for your answer, but torch.sort() is differentiable, right? Not easy to optimize, but differentiable nevertheless?
Yes, it is differentiable almost everywhere, with a Jacobian matrix that is a permutation matrix.
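A quick way to see this, using only core PyTorch: backpropagating through torch.sort routes each upstream gradient entry back to the input that produced it, i.e. it multiplies by a permutation matrix and never mixes coordinates.

import torch

x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
sorted_x, _ = torch.sort(x)  # [1., 2., 3.]

# Backpropagate a distinctive upstream gradient.
sorted_x.backward(torch.tensor([10.0, 20.0, 30.0]))

# The gradient is the upstream vector permuted, never averaged or mixed:
print(x.grad)  # tensor([30., 10., 20.])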