LongAlign
LongAlign copied to clipboard
关于Packing和 直接Batch的loss区别?
论文中指出Packing Loss和直接Batch Loss不一致,是基于这个公式: 即:以样本为粒度,算loss 先在样本内平均,再batch内平均,两步走。
基于我的认知,SFT训练中一般是以Token为粒度算最终的loss的,即 "target token loss 总和 / target token 总数",并非样本粒度。
我看了下你的代码实现,即modeling_llama.py文件中按直接Batch算,loss是 从 batch*seq 直接Flat成一个seq,还是直接以token为粒度计算的loss,并非样本粒度(即先在seq 求平均,再在batch求平均)
有两个问题讨论:
- SFT中loss 最后一步的平均, 究竟应该以Token为粒度 还是以样本为粒度?
- 如果以Token为粒度,我认为Packing和非Packing是等价的