FlagEmbedding
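[Code share] Adding evaluation data when fine-tuning BGE to monitor validation-set metrics
The official code does not report any validation-set metrics during training, which makes it hard to tell whether the model is overfitting. I tried adding compute_metrics, but that did not work either: the Trainer's evaluate logic is fairly complex and never reaches it, so in the end evaluate itself has to be rewritten. Below is a very simple rewrite for reference. It returns the validation-set loss during training; you only need to add the usual do_eval, eval_steps, evaluation_strategy and similar arguments. You can extend the validation logic to suit your own needs.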
In trainer.py, inside class BiTrainer(Trainer):
    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        # Requires `import torch` and `from tqdm import tqdm` at the top of trainer.py.
        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        # Use the Trainer's device instead of a hard-coded one (originally 'npu').
        device = self.args.device
        self.model.eval()
        losses = []
        with torch.no_grad():  # no gradients are needed for validation
            for inputs in tqdm(dataset, desc='evaluate'):
                # Collate one example at a time, i.e. an eval batch size of 1.
                inputs = self.data_collator([inputs])
                for side in ('query', 'passage'):
                    for key in ('input_ids', 'attention_mask'):
                        inputs[side][key] = inputs[side][key].to(device)
                # Remove the extra fields before computing the loss (ignored if absent).
                inputs.pop('teacher_scores', None)
                inputs.pop('bi_directions', None)
                loss = self.compute_loss(self.model, inputs)
                losses.append(loss.mean().item())
        # With the default prefix this is logged as 'eval_loss'.
        metrics = {f'{metric_key_prefix}_loss': sum(losses) / len(losses)}
        self.log(metrics)
        return metrics
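Once eval_loss is being logged, one simple way to check for overfitting is to see whether the validation loss has stopped improving over the last few evaluations. The snippet below is only a rough sketch of that idea: the helper names best_eval_step and looks_overfit, the patience value, and the loss numbers are all made up for illustration and are not part of FlagEmbedding or the Trainer.

```python
from typing import List, Optional


def best_eval_step(eval_losses: List[float]) -> Optional[int]:
    """Return the index of the lowest validation loss, or None if the list is empty."""
    if not eval_losses:
        return None
    return min(range(len(eval_losses)), key=eval_losses.__getitem__)


def looks_overfit(eval_losses: List[float], patience: int = 3) -> bool:
    """True if none of the last `patience` evaluations beat the earlier best loss."""
    if patience < 1 or len(eval_losses) <= patience:
        return False  # not enough history to judge
    best_before = min(eval_losses[:-patience])
    return all(loss >= best_before for loss in eval_losses[-patience:])


# Hypothetical eval_loss values, one per evaluation round.
history = [0.92, 0.71, 0.63, 0.66, 0.70, 0.74]
print(best_eval_step(history))  # index of the best round: 2
print(looks_overfit(history))   # True: the last 3 rounds never beat 0.63
```

If the check returns True, the checkpoint saved around the best round is probably the one to keep. The right patience depends on how often you evaluate (eval_steps).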
Hello, may I ask how much GPU memory fine-tuning bge-m3 consumes?