mdistiller
mdistiller copied to clipboard
关于mask的一些问题
尊敬的作者,您好,请教您一个问题: 在复现kdk的代码中,我发现了关于mask的三个函数
def _get_gt_mask(logits, target):
target = target.reshape(-1)
mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
return mask
def _get_other_mask(logits, target):
target = target.reshape(-1)
mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
return mask
def cat_mask(t, mask1, mask2):
t1 = (t * mask1).sum(dim=1, keepdims=True)
t2 = (t * mask2).sum(1, keepdims=True)
rt = torch.cat([t1, t2], dim=1)
return rt
而这三个函数其实在paper中的伪代码中是没有的,请问这三个函数有何特殊意义呢?
这是dkd的一种更快速的实现,具体可以参考#issue1 中的解答。之所以这样实现是因为我们发现在torch的框架下,使用index的方式去读取non-target类的logits是非常耗时的,而转而采取mask相乘的机制去表示DKD,两种实现方式的性能是没有任何区别的。