FastBERT
The FastBERTClassifier class implementation does not use the attention mask
Hi, the classification layer's forward does not use the attention mask, so padding tokens affect the result. You can verify this yourself.
If the classification layer doesn't use the mask, attention will still attend to padding during training, and at inference time padded vs. unpadded inputs produce slightly different results. The impact on accuracy is not large, though, since the model also learns that the padding embeddings carry no useful information.
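For anyone hitting the same issue: the standard fix is to push the attention scores at padded positions to a large negative value before the softmax, so padding receives near-zero weight. A minimal sketch of the idea, with an illustrative scoring function rather than the repo's actual FastBERTClassifier internals:

import torch
import torch.nn.functional as F

def masked_attention_pool(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token
    scores = hidden_states.mean(dim=-1)                     # (batch, seq_len), toy scoring function
    scores = scores.masked_fill(attention_mask == 0, -1e4)  # padding gets ~zero softmax weight
    weights = F.softmax(scores, dim=-1).unsqueeze(-1)       # (batch, seq_len, 1)
    return (weights * hidden_states).sum(dim=1)             # (batch, hidden)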
Yes, the missing mask in the classification layer is a bug in the code. I'll fix it, thanks!
OK, sounds good.
One more issue: the batched inference code doesn't run as-is. I modified it; tested it myself and it works:
if inference:
    batch_size = hidden_states.shape[0]
    nb_class = int(self.num_class.item())
    nb_classifiers = len(self.layer_classifiers)
    device = hidden_states.device
    # positions tracks the original batch index of each remaining sample
    # as early-exiting samples are removed from the batch
    layer_idxes = torch.zeros(batch_size, device=device, dtype=torch.int)
    final_probs = torch.zeros((batch_size, nb_class), device=device)
    uncertain_infos = torch.zeros((batch_size, nb_classifiers), device=device)
    positions = torch.arange(start=0, end=batch_size, device=device, dtype=torch.long)
    for i, (layer_module, (k, layer_classifier_module)) in enumerate(
            zip(self.layers, self.layer_classifiers.items())):
        hidden_states = layer_module(hidden_states, attention_mask)
        logits = layer_classifier_module(hidden_states, attention_mask)
        prob = F.softmax(logits, dim=-1)
        log_prob = F.log_softmax(logits, dim=-1)
        # normalized entropy in [0, 1]: H(p) / log(num_class)
        uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
        uncertain_infos[positions, i] = uncertain
        enough_info = uncertain < inference_speed
        if i != nb_classifiers - 1:
            if torch.any(enough_info):  # some samples meet the uncertainty threshold
                certain_positions = positions[enough_info]
                final_probs[certain_positions] = prob[enough_info]
                layer_idxes[certain_positions] = i
                hidden_states = hidden_states[~enough_info]
                attention_mask = attention_mask[~enough_info]
                # if we have processed all the samples, stop early
                if hidden_states.shape[0] == 0:
                    return final_probs, layer_idxes, uncertain_infos
                positions = positions[~enough_info]  # re-index positions to the shrunken batch
        else:  # final classifier: whatever remains exits here
            final_probs[positions] = prob
            layer_idxes[positions] = i
    return final_probs, layer_idxes, uncertain_infos
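The key trick above is the positions tensor: the working batch shrinks at every layer, but results must be written back to the original batch slots. A stand-alone toy version of that bookkeeping (all names here are illustrative):

import torch

batch = torch.tensor([0.9, 0.1, 0.5, 0.05])  # stand-in "uncertainty" per sample
positions = torch.arange(batch.shape[0])     # original indices of remaining samples
final = torch.zeros(batch.shape[0])

done = batch < 0.2                           # samples confident enough to exit
final[positions[done]] = batch[done]         # write back to the ORIGINAL slots
batch = batch[~done]                         # shrink the working batch
positions = positions[~done]                 # keep original indices aligned

print(final)      # tensor([0.0000, 0.1000, 0.0000, 0.0500])
print(positions)  # tensor([0, 2]) -- samples 0 and 2 are still in flight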
infer.py needs a corresponding change as well:
def infer_model(master_gpu_id, model, dataset, batch_size,
                use_cuda=False, num_workers=1, inference_speed=None, dump_info_file=None):
    global global_step
    global debug_break
    model.eval()
    infer_dataloader = data.DataLoader(dataset=dataset,
                                       collate_fn=TextCollate(dataset),
                                       pin_memory=use_cuda,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       shuffle=False)
    correct_sum = 0
    num_sample = len(infer_dataloader.dataset)
    predicted_probs = []
    true_labels = []
    infos = []
    logging.info("Inference Model...")
    stime_all = time.time()
    for step, batch in enumerate(tqdm(infer_dataloader, unit="batch", ncols=100, desc="Inference process: ")):
        texts = batch["texts"]
        tokens = batch["tokens"].cuda(master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            # the whole batch goes through the early-exit forward pass at once
            probs, layer_idxes, uncertain_infos = model(tokens, token_type_ids=segment_ids,
                                                        attention_mask=attn_masks,
                                                        inference=True, inference_speed=inference_speed)
        _, top_index = probs.topk(1)
        correct_sum += (top_index.view(-1) == labels).sum().item()
        if dump_info_file is not None:
            for label, pred, prob, layer_idx, text in zip(labels, top_index.view(-1), probs, layer_idxes, texts):
                infos.append((label.item(), pred.item(), prob.cpu().numpy(), layer_idx.item(), text))
        if debug_break and step > 50:
            break
    time_per = (time.time() - stime_all) / num_sample
    time_all = time.time() - stime_all
    acc = format(correct_sum / num_sample, "0.4f")
    logging.info("speed_arg:%s, time_per_record:%s, acc:%s, total_time:%s",
                 inference_speed, format(time_per, '0.4f'), acc, format(time_all, '0.4f'))
    if dump_info_file is not None and len(dump_info_file) != 0:
        with open(dump_info_file, 'w') as fw:
            for label, pred, prob, layer_i, text in infos:
                fw.write('\t'.join([str(label), str(pred), str(layer_i), text]) + '\n')
    if probs.shape[1] == 2:  # binary classification: also report precision/recall
        labels_pr = [info[0] for info in infos]
        preds_pr = [info[1] for info in infos]
        precise, recall = eval_pr(labels_pr, preds_pr)
        logging.info("precise:%s, recall:%s", format(precise, '0.4f'), format(recall, '0.4f'))
Found another bug... In the init_adam_optimizer function in utils.py, the initialization of optimizer_parameters is wrong: the no_decay group always ends up with 0 parameters.
Change it to:
optimizer_parameters = [
    {"params": [p for param_name, p in model.named_parameters()
                if not any(name in param_name for name in no_decay)],
     "weight_decay_rate": 0.01},
    {"params": [p for param_name, p in model.named_parameters()
                if any(name in param_name for name in no_decay)],
     "weight_decay_rate": 0.0},
]
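A quick way to sanity-check the fix is to count the parameters in each group; with a typical no_decay list (the exact names in utils.py may differ, these are assumptions) the second group should be non-empty:

no_decay = ["bias", "LayerNorm.weight"]  # assumed list; check utils.py for the actual names

decay_params = [p for param_name, p in model.named_parameters()
                if not any(name in param_name for name in no_decay)]
no_decay_params = [p for param_name, p in model.named_parameters()
                   if any(name in param_name for name in no_decay)]

# after the fix, the no-decay group should contain the biases and LayerNorm weights
print(len(decay_params), len(no_decay_params))
assert len(no_decay_params) > 0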