
BUG: when fine-tuning bge-m3 on no_in_batch_neg data, get_local_score indexes out of bounds while computing ensemble_scores

Open · KeepGoingCSU opened this issue 9 months ago · 1 comment

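Script to run: the official example https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/embedder/encoder_only/m3_same_dataset.sh. Running the script above with per_device_train_batch_size changed to 4 triggers the error (with the original value of 2 there is no error).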

Error traceback:

File "/mnt/bn/rc-tob-lq/users/huangrong.max/FlagEmbedding/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py", line 426, in forward
[rank1]:     ensemble_scores, ensemble_loss = compute_loss_func(
[rank1]:   File "/mnt/bn/rc-tob-lq/users/huangrong.max/FlagEmbedding/FlagEmbedding/abc/finetune/embedder/AbsModeling.py", line 149, in _compute_no_in_batch_neg_loss
[rank1]:     local_scores = self.compute_local_score(q_reps, p_reps, compute_score_func, **kwargs)   # (batch_size, group_size)
[rank1]:   File "/mnt/bn/rc-tob-lq/users/huangrong.max/FlagEmbedding/FlagEmbedding/abc/finetune/embedder/AbsModeling.py", line 140, in compute_local_score
[rank1]:     loacl_scores = self.get_local_score(q_reps, p_reps, all_scores)
[rank1]:   File "/mnt/bn/rc-tob-lq/users/huangrong.max/FlagEmbedding/FlagEmbedding/abc/finetune/embedder/AbsModeling.py", line 117, in get_local_score
[rank1]:     all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i]

Root-cause analysis:

Parameter settings

Analysis with no_in_batch_neg input data, per_device_train_batch_size=4, and train_group_size=8:

  1. At https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py#L426, when ensemble_scores is computed, the dense_scores, sparse_scores, and colbert_scores passed in were all produced by _compute_no_in_batch_neg_loss, so each of them has shape [2, 8].
  2. At https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/abc/finetune/embedder/AbsModeling.py#L117, the computation in step 1 calls get_local_score(self, q_reps, p_reps, all_scores) and passes through three tensors with shapes [2, 1024], [16, 1024], and [2, 8], so the indexing of all_scores on line 117 goes out of bounds.
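The overflow can be reproduced from the index arithmetic alone. Below is a minimal pure-Python sketch (no PyTorch) that assumes, beyond what this issue shows, that get_local_score computes group_size = p_reps.size(0) // q_reps.size(0) and, for each query q and offset i, reads column q * group_size + i of all_scores. In other words, it assumes get_local_score expects the full (num_queries, num_passages) score matrix. The helpers local_score_columns, first_out_of_bounds, and select_local_scores are hypothetical and exist only for illustration. select_local_scores shows one possible shape guard and is not necessarily how PR #1424 fixes the bug.

```python
from __future__ import annotations

from typing import List, Optional, Tuple


def local_score_columns(num_queries: int, num_passages: int) -> List[List[int]]:
    """Columns the line-117 indexing would read for each query (ASSUMED logic)."""
    group_size = num_passages // num_queries
    # indices = [0, group_size, 2 * group_size, ...]; line 117 reads indices + i
    return [[q * group_size + i for i in range(group_size)] for q in range(num_queries)]


def first_out_of_bounds(columns: List[List[int]], score_width: int) -> Optional[Tuple[int, int]]:
    """Return the first (query, column) pair outside the score matrix, or None."""
    for q, cols in enumerate(columns):
        for c in cols:
            if c >= score_width:
                return (q, c)
    return None


def select_local_scores(all_scores: List[List[float]], num_queries: int, num_passages: int) -> List[List[float]]:
    """HYPOTHETICAL guard: only re-index when given the full score matrix."""
    group_size = num_passages // num_queries
    if len(all_scores[0]) == group_size:
        # Scores are already (num_queries, group_size), as on the no_in_batch_neg path.
        return all_scores
    cols = local_score_columns(num_queries, num_passages)
    return [[all_scores[q][c] for c in cols[q]] for q in range(num_queries)]


# Shapes reported in the issue: q_reps [2, 1024], p_reps [16, 1024], all_scores [2, 8].
num_queries, num_passages, score_width = 2, 16, 8
cols = local_score_columns(num_queries, num_passages)
print(cols[1])                                   # [8, 9, 10, 11, 12, 13, 14, 15]
print(first_out_of_bounds(cols, score_width))    # (1, 8)
```

With the reported shapes, group_size is 16 // 2 = 8, so query 1 reads columns 8 through 15 of a score matrix that has only 8 columns. With a full [2, 16] matrix, first_out_of_bounds would return None.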

@hanhainebula could you please take a look?

KeepGoingCSU, Mar 25 '25 10:03

Hi @KeepGoingCSU! Thank you very much for pointing out this bug. I have fixed it in PR #1424.

hanhainebula, Apr 10 '25 12:04