RRHF icon indicating copy to clipboard operation
RRHF copied to clipboard

算loss的时候求均值的时候是不是可以优化

Open shyoulala opened this issue 6 months ago • 6 comments

我看到在sft_loss 的时候直接求了平均,平均的分母是样本label的长度,包括不参与训练的,是否应该采用mask mean 就像: item = -logit_label[max_idx] return -torch.sum(item)/ torch.sum(labels!=-100)。##因为在gather_logits_labels 这一步把-100的prob已经变成0了 而不是-logit_label[max_idx].mean()

image

shyoulala avatar Dec 21 '23 12:12 shyoulala