sst
Implementation of softtopk_np()
Hi, I'm wondering if you can share some insight into the choice of NumPy rather than PyTorch for the implementation of softtopk_np() in sst.core.topk.py? As far as I can tell, there is nothing in it that could not be implemented equivalently in PyTorch. Please correct me if I'm wrong about this. Many thanks.
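(For what it's worth, the log-space primitives used in the port below all have direct torch counterparts, so a one-to-one translation looks possible. A minimal sketch, assuming the NumPy version relies on the same operations the port uses:)

    import numpy as np
    import torch

    x = np.random.randn(4, 5)
    y = np.random.randn(4, 5)
    t_x, t_y = torch.from_numpy(x), torch.from_numpy(y)

    np.logaddexp(x, y)               # <-> torch.logaddexp(t_x, t_y)
    np.log(np.exp(x).sum(axis=-1))   # logsumexp <-> torch.logsumexp(t_x, dim=-1)
    np.concatenate([x, y], axis=1)   # <-> torch.cat([t_x, t_y], dim=1)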
Following this issue, because I also found the NumPy implementation very inefficient... I tried to substitute the proposed NumPy implementation with a torch-based equivalent, but performance dropped a lot:
import torch

INF = 1e9  # large finite stand-in for infinity (assumed value; the constant is not defined in the snippet)


def softtopk_forward(logits, k):
    # Forward DP: messages[:, i, j] is the log-weight of selecting exactly j
    # items among the first i + 1 positions.
    batchsize, n = logits.shape
    messages = -INF * torch.ones((batchsize, n, k + 1)).cuda()
    messages[:, 0, 0] = 0
    messages[:, 0, 1] = logits[:, 0]
    for i in range(1, n):
        for j in range(k + 1):
            logp_dont_use = messages[:, i - 1, j]
            if j > 0:
                logp_use = messages[:, i - 1, j - 1] + logits[:, i]
            else:
                logp_use = -INF * torch.ones((batchsize,)).cuda()
            message = torch.logaddexp(logp_dont_use, logp_use)
            messages[:, i, j] = message
    return messages


def softtopk_backward(logits, k):
    # Backward DP, run from the last position towards the first.
    batchsize, n = logits.shape
    messages = -INF * torch.ones((batchsize, n, k + 1)).cuda()
    messages[:, n - 1, k] = 0
    for i in range(n - 2, -1, -1):
        for j in range(k + 1):
            logp_dont_use = messages[:, i + 1, j]
            if j < k:
                logp_use = messages[:, i + 1, j + 1] + logits[:, i + 1]
            else:
                logp_use = -INF * torch.ones((batchsize,)).cuda()
            message = torch.logaddexp(logp_dont_use, logp_use)
            messages[:, i, j] = message
    return messages


def softtopk_torch(logits, k):
    # Combine forward and backward messages into per-position marginals.
    batchsize = logits.shape[0]
    f = softtopk_forward(logits, k)
    b = softtopk_backward(logits, k)
    initial_f = -INF * torch.ones((batchsize, 1, k + 1)).cuda()
    initial_f[:, :, 0] = 0
    ff = torch.cat([initial_f, f[:, :-1, :]], dim=1)
    lse0 = torch.logsumexp(ff + b, dim=2)
    lse1 = torch.logsumexp(ff[:, :, :-1] + b[:, :, 1:], dim=2) + logits
    return torch.exp(lse1 - torch.logaddexp(lse0, lse1))


class SoftTopK_torch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, k, eps):
        # ctx is a context object that can be used to stash information
        # for backward computation.
        ctx.save_for_backward(logits)
        ctx.k = k
        ctx.eps = eps
        mu = softtopk_torch(logits, k)
        return mu

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        r"""http://www.cs.toronto.edu/~kswersky/wp-content/uploads/carbm.pdf"""
        logits, = ctx.saved_tensors
        k = ctx.k
        eps = ctx.eps
        # Finite-difference gradient estimate in the direction of grad_output.
        n1 = softtopk_torch(logits + eps * grad_output, k)
        n2 = softtopk_torch(logits - eps * grad_output, k)
        grad = (n1 - n2) / (2 * eps)
        return grad, None, None
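For completeness, the custom Function above is invoked via apply; a minimal usage sketch with made-up shapes and values:

    logits = torch.randn(8, 20, requires_grad=True, device="cuda")  # hypothetical batch of scores
    mu = SoftTopK_torch.apply(logits, 5, 1e-2)  # soft top-5 marginals, shape (8, 20)
    mu.sum().backward()                         # gradient flows through the finite-difference rule in backward()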
This implementation is highly inefficient, even compared with the NumPy one.
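Most of the cost probably comes from the Python double loop: every (i, j) step launches a handful of tiny CUDA kernels, and the boundary cases allocate fresh tensors with .cuda() inside the loop. One way to mitigate this (a sketch under the assumption that removing the inner j loop is the main win; this is not the authors' implementation) is to vectorize over j, so each position i costs a single cat plus a single logaddexp over the whole (batchsize, k + 1) slice:

    def softtopk_forward_vec(logits, k, inf=1e9):
        # Same recursion as softtopk_forward above, but the j loop is replaced
        # by shifting the previous row; assumes k >= 1, as in the original.
        batchsize, n = logits.shape
        neg_inf = torch.full((batchsize, 1), -inf, device=logits.device)
        msg = torch.full((batchsize, k + 1), -inf, device=logits.device)
        msg[:, 0] = 0
        msg[:, 1] = logits[:, 0]
        rows = [msg]
        for i in range(1, n):
            shifted = torch.cat([neg_inf, msg[:, :-1]], dim=1) + logits[:, i:i + 1]
            msg = torch.logaddexp(msg, shifted)
            rows.append(msg)
        return torch.stack(rows, dim=1)  # (batchsize, n, k + 1)

The backward messages can be vectorized the same way (shifting in the other direction), and keeping the per-step rows in a list avoids both the repeated .cuda() allocations and the in-place writes into one big preallocated tensor. Whether this actually closes the gap to the NumPy version would still need benchmarking on your setup.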