pykt-toolkit
pykt-toolkit copied to clipboard
sparseKT-ktop算法存在的问题
您好,感谢您提供的代码,让我有机会能够进行学习。 我在读sparseKT代码的时候发现了ktop在实现上的一些问题。 q1:在attention方法中,scores = F.softmax(scores, dim=-1) 后需要需要乘上scores = scores * mask.float().to(device)?因为在注意力权重中的第一行需要全部经过掩码操作,经过softmax后,实际的结果并不是0,而是一个接近0的极小数。 q2:ktop算法的具体实现。代码将scores分解成了scores_a和scores_b。原scores中第一行的值为全0,第二行的值应该为1+199个0。代码实现是在将scores_a和scores_b重新拼接后再进行softmax操作,这样会导致scores_a中的所有行需要重新分配注意力,原本因为掩码作用为0的位置又重新获得了注意力权重,这就导致了偷看到了未来位置。修改方法如下:是否应该先对scores_b进行softmax操作,然后再将scores_b和scores_a进行拼接。 修改后: scores_a = scores[:, :, :k_index, :] scores_b = scores[:, :, k_index:, :].reshape(bshead(seqlen-k_index), -1) sorted_scores,sorted_idx = torch.sort(scores_b,descending=True) scores_t = sorted_scores[:,k_index-1:k_index].repeat(1,seqlen) scores_b = torch.where(scores_b - scores_t >= torch.tensor(0), scores_b, torch.tensor(-1e32)).reshape(bs,head,seqlen-k_index,-1) scores_b = F.softmax(scores_b, dim=-1) # BS,8,seqlen,seqlen scores = torch.cat([scores_a, scores_b], dim=2) 由于本人刚刚入门,担心作者的实现逻辑顺序有其他用意,理解不对的地方希望能够获得指正,期待您的回复!