
Implementation of softtopk_np()

Open · KatarinaYuan opened this issue 3 years ago · 1 comment

Hi, I'm wondering if you could share some insight into why softtopk_np() in sst.core.topk.py is implemented in NumPy rather than PyTorch. As far as I can tell, nothing in it couldn't be implemented equivalently in PyTorch; please correct me if I'm wrong. Many thanks.

KatarinaYuan · Dec 11 '21 14:12

I'm following this issue because I also found the NumPy implementation very inefficient. I tried to replace it with a PyTorch-based equivalent, but performance dropped a lot:

import torch

# Assumed constant: a large finite stand-in for infinity in log-space
# (match whatever INF / INF_T is defined as in the repo).
INF_T = 1e9

def softtopk_forward(logits, k):
    # Forward messages of the dynamic program: messages[:, i, j] is the
    # unnormalized log-mass of selecting exactly j of the first i + 1 items.
    batchsize, n = logits.shape
    messages = -INF_T * torch.ones((batchsize, n, k + 1)).cuda()
    messages[:, 0, 0] = 0
    messages[:, 0, 1] = logits[:, 0]
    for i in range(1, n):
        for j in range(k + 1):
            logp_dont_use = messages[:, i - 1, j]
            if j > 0:
                logp_use = messages[:, i - 1, j - 1] + logits[:, i] 
            else:
                logp_use = -INF_T * torch.ones((batchsize,)).cuda()            
            message = torch.logaddexp(logp_dont_use, logp_use)
            messages[:, i, j] = message
    return messages

def softtopk_backward(logits, k):
    # Backward messages: messages[:, i, j] is the unnormalized log-mass of
    # completing the selection from count j after item i to exactly k items,
    # using items i + 1 .. n - 1.
    batchsize, n = logits.shape
    messages = -INF_T * torch.ones((batchsize, n, k + 1)).cuda()
    messages[:, n - 1, k] = 0
    for i in range(n - 2, -1, -1):
        for j in range(k + 1):
            logp_dont_use = messages[:, i + 1, j]
            if j < k:
                logp_use = messages[:, i + 1, j + 1] + logits[:, i + 1] 
            else:
                logp_use = -INF_T * torch.ones((batchsize,)).cuda()          
            message = torch.logaddexp(logp_dont_use, logp_use)
            messages[:, i, j] = message
    return messages

def softtopk_torch(logits, k):
    # Combine forward and backward messages to obtain the relaxed top-k
    # marginals mu[:, i] = P(item i is among the k selected items).
    batchsize = logits.shape[0]
    f = softtopk_forward(logits, k)
    b = softtopk_backward(logits, k)
    initial_f = -INF_T * torch.ones((batchsize, 1, k + 1)).cuda()
    initial_f[:, :, 0] = 0
    ff = torch.cat([initial_f, f[:, :-1, :]], dim=1)
    # lse0: log-mass of configurations in which item i is not selected;
    # lse1: log-mass of configurations in which item i is selected.
    lse0 = torch.logsumexp(ff + b, dim=2)
    lse1 = torch.logsumexp(ff[:, :, :-1] + b[:, :, 1:], dim=2) + logits
    return torch.exp(lse1 - torch.logaddexp(lse0, lse1))

class SoftTopK_torch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, k, eps):
        # ctx is a context object that can be used to stash information
        # for backward computation.
        ctx.save_for_backward(logits)
        ctx.k = k
        ctx.eps = eps
        mu = softtopk_torch(logits, k)
        return mu
    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        r"""http://www.cs.toronto.edu/~kswersky/wp-content/uploads/carbm.pdf"""
        logits, = ctx.saved_tensors
        k = ctx.k
        eps = ctx.eps
        # Central finite differences in the direction of grad_output
        # (see the reference above).
        n1 = softtopk_torch(logits + eps * grad_output, k)
        n2 = softtopk_torch(logits - eps * grad_output, k)
        grad = (n1 - n2) / (2 * eps)
        return grad, None, None
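
For reference, here is a minimal usage sketch of the class above (sizes and eps are arbitrary, and the hard-coded .cuda() calls mean it only runs on a GPU):

def example_usage():
    # Arbitrary sizes chosen for illustration.
    batchsize, n, k, eps = 32, 100, 5, 1e-2
    logits = torch.randn(batchsize, n, device="cuda", requires_grad=True)
    mu = SoftTopK_torch.apply(logits, k, eps)  # relaxed top-k marginals, shape (batchsize, n)
    mu.sum().backward()                        # gradient w.r.t. logits via the finite-difference rule
    return logits.grad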

This implementation turns out to be highly inefficient, even compared with the NumPy one.
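
I suspect the main cost is the nested Python loops: every messages[:, i, j] update launches its own tiny CUDA kernel, so the GPU sits mostly idle. Below is a sketch (not a tested drop-in replacement) of how the inner loop over j could be vectorized with a shifted logaddexp, leaving only the sequential loop over positions; it reuses the INF_T constant assumed above:

def softtopk_forward_vec(logits, k):
    # Same recursion as softtopk_forward, but all k + 1 count states are
    # updated at once with a single shifted logaddexp per position i.
    batchsize, n = logits.shape
    neg_inf_col = torch.full((batchsize, 1), -INF_T, device=logits.device)
    m = torch.full((batchsize, k + 1), -INF_T, device=logits.device)
    m[:, 0] = 0
    if k >= 1:
        m[:, 1] = logits[:, 0]
    rows = [m]
    for i in range(1, n):
        prev = rows[-1]
        # "use item i": shift the count index by one and add its logit; j = 0 stays -INF_T.
        use = torch.cat([neg_inf_col, prev[:, :-1]], dim=1) + logits[:, i:i + 1]
        rows.append(torch.logaddexp(prev, use))
    return torch.stack(rows, dim=1)  # (batchsize, n, k + 1), same layout as softtopk_forward

The backward messages admit the same treatment (shifting in the other direction), and softtopk_torch / SoftTopK_torch can sit unchanged on top of the vectorized helpers.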

Emanuele97x · Jul 20 '22 10:07