LongAlign
LongAlign copied to clipboard
packing loss 的归一化问题
这里的loss计算是不是应该归一化一下
loss = (loss * shift_weights).sum() -> loss = (loss * shift_weights).sum() / shift_weights.sum()
把loss归一化到token粒度 前一种方式,loss的scale偏大,而且反向传播梯度也会偏大。而且极限情况下,假设每个样本只有1个token,这个batch的loss会爆炸
这里的shift_weights已经经过归一化了。每个sample的weight加起来为1。