LongAlign icon indicating copy to clipboard operation
LongAlign copied to clipboard

关于Packing和 直接Batch的loss区别?

Open BitVoyage opened this issue 4 months ago • 3 comments

论文中指出Packing Loss和直接Batch Loss不一致,是基于这个公式: image 即:以样本为粒度,算loss 先在样本内平均,再batch内平均,两步走。

基于我的认知,SFT训练中一般是以Token为粒度算最终的loss的,即 "target token loss 总和 / target token 总数",并非样本粒度。

我看了下你的代码实现,即modeling_llama.py文件中按直接Batch算,loss是 从 batch*seq 直接Flat成一个seq,还是直接以token为粒度计算的loss,并非样本粒度(即先在seq 求平均,再在batch求平均) image

有两个问题讨论:

  1. SFT中loss 最后一步的平均, 究竟应该以Token为粒度 还是以样本为粒度?
  2. 如果以Token为粒度,我认为Packing和非Packing是等价的

BitVoyage avatar Feb 18 '24 12:02 BitVoyage