指标计算方式
和作者确认一个事情:在计算F1和准召指标时,是不是只用了【实体词】完全匹配即可,没有考虑【实体词,开始位置,结束位置】三者完全匹配,代码位置在train.py 中的 validate(self, model, dev_loader)方法中
''' for text, logit, entity_result in zip(texts, logits, entity_results): p_results, p_results_detailed = self.data_manager.extract_entities(text, logit) for class_id, entity_set in entity_result.items(): p_entity_set = p_results.get(class_id) if p_entity_set is None: # 没预测出来 p_entity_set = set() # 预测出来并且正确个数 counts[class_id]['A'] += len(p_entity_set & entity_set) # 预测出来的结果个数 counts[class_id]['B'] += len(p_entity_set) # 真实的结果个数 counts[class_id]['C'] += len(entity_set) '''
和作者确认一个事情:在计算F1和准召指标时,是不是只用了【实体词】完全匹配即可,没有考虑【实体词,开始位置,结束位置】三者完全匹配,代码位置在train.py 中的 validate(self, model, dev_loader)方法中
''' for text, logit, entity_result in zip(texts, logits, entity_results): p_results, p_results_detailed = self.data_manager.extract_entities(text, logit) for class_id, entity_set in entity_result.items(): p_entity_set = p_results.get(class_id) if p_entity_set is None: # 没预测出来 p_entity_set = set() # 预测出来并且正确个数 counts[class_id]['A'] += len(p_entity_set & entity_set) # 预测出来的结果个数 counts[class_id]['B'] += len(p_entity_set) # 真实的结果个数 counts[class_id]['C'] += len(entity_set) '''
对的,没有考虑位置的信息。
和作者确认一个事情:在计算F1和准召指标时,是不是只用了【实体词】完全匹配即可,没有考虑【实体词,开始位置,结束位置】三者完全匹配,代码位置在train.py 中的 validate(self, model, dev_loader)方法中 ''' for text, logit, entity_result in zip(texts, logits, entity_results): p_results, p_results_detailed = self.data_manager.extract_entities(text, logit) for class_id, entity_set in entity_result.items(): p_entity_set = p_results.get(class_id) if p_entity_set is None: # 没预测出来 p_entity_set = set() # 预测出来并且正确个数 counts[class_id]['A'] += len(p_entity_set & entity_set) # 预测出来的结果个数 counts[class_id]['B'] += len(p_entity_set) # 真实的结果个数 counts[class_id]['C'] += len(entity_set) '''
对的,没有考虑位置的信息。
想请教一下是为什么没有考虑呢,如果一个数据样本中同一个实体词出现多次,有的预测出来了,有的没有预测出来,会不会无法反映模型真实的性能呢。
如果要考虑位置信息的话应该怎么修改呢,我的想法是: 1、对于训练数据,在 data.py 的 prepare_data(self, data) 方法中将 entity_results.setdefault(class_id, set()).add(entity['entity']) 修改成: entity_results.setdefault(class_id, set()).add((entity['entity'], start_idx, end_idx+1))
2、对于预测结果,在 data.py 的 extract_entities(self, text, model_output)方法中将 predict_results.setdefault(class_id, set()).add(entity_text) 修改成: predict_results.setdefault(class_id, set()).add((entity_text, start_in_text, end_in_text+1))