FlagEmbedding
FlagEmbedding copied to clipboard
bug:多卡同步的时候,cross_targets计算方式是不是有问题?
代码第265行,多卡数据同步之后,cross_targets计算方式有问题,应该得考虑当前local rank。 https://github.com/FlagOpen/FlagEmbedding/blob/97f57a1b92dc68d56731a1e38a2d3aad4cd67e20/FlagEmbedding/BGE_M3/modeling.py#L265
原始是:cross_targets = idxs_cross * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0)) 应该是:cross_targets = idxs_cross * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))+self.process_rank*p_dense_vecs.size(0)
您好,应该是没有问题的。多卡数据同步后,分数依然是一个矩阵,第i个query的pos在第i*group-size个。