FlagEmbedding icon indicating copy to clipboard operation
FlagEmbedding copied to clipboard

Update negative sampling in hn_mine to fix issue #464

Open shtdbb opened this issue 1 year ago • 1 comments

关于解决 #464 的修改。 避免困难样本挖掘时,当召回负样本数量少于预设负采样数量,会随机采样到正样本、或重复采样负样本的问题。 修改为,默认从 corpus 中剔除正例和已召回的负例,再进行随机采样;若剔除后 corpus 为空,说明需要重复采样负样本才能满足负采样数量要求,则只剔除正样本、重复采样负样本即可。

shtdbb avatar Feb 20 '24 08:02 shtdbb

感谢您的PR!但是目前的操作看起来比较复杂,可能会导致比较大的时间消耗。这块可能还需要再好好考虑一下。

staoxiao avatar Feb 21 '24 06:02 staoxiao

感谢您的PR!但是目前的操作看起来比较复杂,可能会导致比较大的时间消耗。这块可能还需要再好好考虑一下。

@staoxiao 您好,我这边修改了一下策略。

方法

我是选择多随机采样一个负样本用于备用,然后采样后查看正样本是否被采样,若正样本在则用备份样本代替,否则直接舍弃备用样本即可。这样每个样本采样后再进行过滤,只需要遍历负样本数量的列表即可。

测试

我这边使用数据集大致测试了一下:使用 10 万个样本的数据集进行负样本挖掘,设置脚本参数 python -m hn_mine --model_name_or_path models/bge-large-zh-v1.5 --input_file dataset.jsonl --range_for_sampling 2-200 --negative_number 100 --output_file dataset_neg.jsonl,设置采样 100 个负样本测试极端情况。计时则单独计算 87-98 行这个 for 循环的运算时间:https://github.com/FlagOpen/FlagEmbedding/blob/db672e177083f25cbf44fe2d5d0ac91f02380726/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py#L87-L98

结果

  • 不做正样本过滤处理:约 8.7s
  • 正样本的后过滤处理:约 9.3s 从结果上来看,从 10 万个样本中负采样 100 个样本,应该能满足大部分微调的负采样需求,故做正样本的后过滤大致的时间花费个人认为是可以接受的~

shtdbb avatar Apr 24 '24 08:04 shtdbb

@shtdbb , 非常感谢您的PR! 有个小问题,data['pos']是一个列表,可能包含多个正样本,无法执行sent != data['pos']。 如果您跑通了这个代码,需要检查数据格式是否正确。data['pos']如果是一个字符串的话,训练会有很大问题(代码将随机选取一个字母作为pos)。

代码建议改为这样:

samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
samples = [sent for sent in samples if sent not in data['pos']]
data['neg'].extend(samples[:negative_number - len(data['neg'])])

staoxiao avatar Apr 24 '24 09:04 staoxiao

@shtdbb , 非常感谢您的PR! 有个小问题,data['pos']是一个列表,可能包含多个正样本,无法执行sent != data['pos']。 如果您跑通了这个代码,需要检查数据格式是否正确。data['pos']如果是一个字符串的话,训练会有很大问题(代码将随机选取一个字母作为pos)。

代码建议改为这样:

samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
samples = [sent for sent in samples if sent not in data['pos']]
data['neg'].extend(samples[:negative_number - len(data['neg'])])

非常感谢您的耐心!很不好意思,我忽略了data['pos']类型是list了。感谢提醒!已按照您的建议修改~

shtdbb avatar Apr 24 '24 13:04 shtdbb

thanks~

staoxiao avatar Apr 24 '24 13:04 staoxiao