FlagEmbedding
FlagEmbedding copied to clipboard
bge-m3统一微调(密集嵌入、稀疏嵌入和colbert)的原理
在进行对bge-m3统一微调(密集嵌入、稀疏嵌入和colbert)的时候,发现训练的代码不是很详细,不太清楚其中的原理
{"query": str, "pos": List[str], "neg":List[str]}
是query+pos,query+neg,进行二分类么
https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py#L302 交叉熵损失,对pos和neg计算分数,将正样本作为正确分类计算损失
https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py#L302 交叉熵损失,对pos和neg计算分数,将正样本作为正确分类计算损失
我debug了一下代码,请问可以这么理解么,计算query和所有passage的分数,然后选取最大分数的那个passage下标作为预测值,实际pos的位置为label,然后计算交叉熵损失