insightface
insightface copied to clipboard
Why do the ms1mv3 configs use ArcFace, while the glint360k and wf42m configs use a margin (CosFace-style) loss?
class CombinedMarginLoss(torch.nn.Module):
    """Apply an angular or cosine margin to the target-class logits and scale them.

    Two margin modes are supported, selected by the margin values:

    * ``m1 == 1.0 and m3 == 0.0``: additive angular margin -- the target
      logit ``cos(theta)`` becomes ``cos(theta + m2)``.
    * ``m3 > 0``: additive cosine margin -- the target logit ``cos(theta)``
      becomes ``cos(theta) - m3`` (``m1`` and ``m2`` are not used here).

    Any other combination is rejected by :meth:`forward`.

    Args:
        s: Scale factor multiplied into every logit after the margin.
        m1: Must be ``1.0`` to select the angular-margin branch.
        m2: Additive angular margin, in radians.
        m3: Additive cosine margin.
        interclass_filtering_threshold: When greater than 0, every
            non-target logit above this value is zeroed before the margin
            is applied. ``0`` disables the filtering.
    """

    def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace
        # NOTE: these precomputed values are not read by forward() below;
        # they are kept so the instance's attributes stay unchanged.
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        """Return the margin-adjusted logits multiplied by ``s``.

        Args:
            logits: 2-D tensor indexed as ``[sample, class]``. It may be
                modified in place (the margin is written into it), so pass a
                copy if the caller still needs the raw values.
            labels: Integer tensor of class indices; ``-1`` marks a sample
                with no target class, which receives no margin.
                Assumes shape ``(N, 1)`` -- the ``scatter_`` call along dim 1
                needs a 2-D index; TODO confirm against the caller.

        Returns:
            Tensor with the same shape as ``logits``.

        Raises:
            ValueError: If the configured margins match neither supported
                mode.
        """
        # Row indices of the samples that actually have a target class.
        index_positive = torch.where(labels != -1)[0]
        target_classes = labels[index_positive].view(-1)

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                # 1.0 wherever a logit exceeds the threshold.
                dirty = (logits > self.interclass_filtering_threshold).float()
                # Clear the flag at each sample's own target class so the
                # target logit itself is never filtered out.
                mask = torch.ones(
                    [index_positive.size(0), logits.size(1)],
                    device=logits.device,
                )
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, target_classes]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # Angular margin: go to angle space, add m2 to the target angle,
            # then return to cosine space. The ops are in place and run
            # under no_grad, so autograd does not record them.
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, target_classes] = final_target_logit
                logits.cos_()
            logits = logits * self.s
        elif self.m3 > 0:
            # Cosine margin: subtract m3 directly from the target cosine.
            final_target_logit = target_logit - self.m3
            logits[index_positive, target_classes] = final_target_logit
            logits = logits * self.s
        else:
            # A bare ``raise`` here has no active exception to re-raise and
            # only yields an opaque RuntimeError; say what is wrong instead.
            raise ValueError(
                "Unsupported margin combination "
                f"(m1={self.m1}, m2={self.m2}, m3={self.m3}): expected "
                "m1 == 1.0 and m3 == 0.0 (angular margin) or m3 > 0 "
                "(cosine margin)."
            )

        return logits
Across the many configs, only the ms1mv3_**** ones set config.margin_list = (1.0, 0.5, 0.0), which is ArcFace; all the others set config.margin_list = (1.0, 0.0, 0.4), which is a margin (CosFace-style) loss. Is ArcFace not suitable for large datasets?