FastBERT icon indicating copy to clipboard operation
FastBERT copied to clipboard

FastBERTClassifier 类的实现未使用 attention mask

Open CaoYiwei opened this issue 2 years ago • 5 comments

您好，FastBERTClassifier 分类层的 forward 里面没有使用 attention mask，会导致 padding 影响结果。您可以验证一下。

CaoYiwei avatar Dec 21 '21 06:12 CaoYiwei


CaoYiwei avatar Dec 21 '21 06:12 CaoYiwei


BitVoyage avatar Dec 27 '21 07:12 BitVoyage



CaoYiwei avatar Jan 06 '22 08:01 CaoYiwei


        if inference:
            batch_size = hidden_states.shape[0]
            nb_class = int(self.num_class.item())
            nb_classifiers = len(self.layer_classifiers)
            device = hidden_states.device
            # positions will keep track of the original position of each element in the
            # batch when elements will be removed
            layer_idxes = torch.zeros(batch_size, device=device,
            final_probs = torch.zeros((batch_size, nb_class), device=device)
            uncertain_infos = torch.zeros((batch_size, nb_classifiers), device=device)
            positions = torch.arange(start=0, end=batch_size, device=device, dtype=torch.long)

            for i, (layer_module, (k, layer_classifier_module)) in enumerate(zip(self.layers, self.layer_classifiers.items())):
                hidden_states = layer_module(hidden_states, attention_mask)
                logits = layer_classifier_module(hidden_states, attention_mask)
                prob = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                uncertain = torch.sum(prob * log_prob, 1) / (-torch.log(self.num_class))
                uncertain_infos[positions, i] = uncertain
                enough_info = uncertain < inference_speed
                if i != nb_classifiers - 1:
                    if torch.any(enough_info):  # 存在满足复杂度条件的样本
                        certain_positions = positions[enough_info]
                        final_probs[certain_positions] = prob[enough_info]
                        layer_idxes[certain_positions] = i
                        hidden_states = hidden_states[~enough_info]
                        attention_mask = attention_mask[~enough_info]

                        # if we have processed all the samples
                        if hidden_states.shape[0] == 0:
                            return final_probs, layer_idxes, uncertain_infos

                        positions = positions[~enough_info]  # updating the positions to fit the new batch                      
                else:  # final classifier
                    final_probs[positions] = prob
                    layer_idxes[positions] = i

            return final_probs, layer_idxes, uncertain_infos


def infer_model(master_gpu_id, model, dataset, batch_size,
                use_cuda=False, num_workers=1, inference_speed=None, dump_info_file=None):
    """Run early-exit inference over ``dataset`` and log accuracy and timing.

    NOTE(review): the pasted original lost several lines. These were the tail
    of the ``DataLoader(...)`` call, the ``logging.info`` calls and the
    ``break`` under ``debug_break``. They are restored here from the surviving
    fragments; confirm against the original module, e.g. whether it logs
    through a module-level ``logger``.

    Args:
        master_gpu_id: CUDA device index the batch tensors are moved to when
            ``use_cuda`` is true.
        model: called as ``model(tokens, token_type_ids=..., attention_mask=...,
            inference=True, inference_speed=...)`` and expected to return
            ``(probs, layer_idxes, uncertain_infos)``.
        dataset: dataset supporting ``len()``. Its batches are dicts with the
            keys ``"texts"``, ``"tokens"``, ``"segment_ids"``, ``"attn_masks"``
            and ``"labels"``.
        batch_size: batch size of the inference ``DataLoader``.
        use_cuda: move the batch tensors to ``master_gpu_id`` before the
            forward pass.
        num_workers: number of ``DataLoader`` worker processes.
        inference_speed: early-exit threshold forwarded to the model.
        dump_info_file: if a non-empty path, one tab-separated line
            ``label, prediction, exit layer, text`` is written per sample.

    Returns:
        None. Results are reported through logging and the optional dump file.
    """
    import logging  # NOTE(review): reconstructed; original logging calls were lost in the paste

    infer_dataloader = data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       # keep predictions in dataset order
                                       shuffle=False,
                                       num_workers=num_workers)
    # One condition decides both collecting and writing the dump. Previously
    # an empty path collected per-sample info but never wrote it.
    dump_enabled = bool(dump_info_file)
    correct_sum = 0
    num_processed = 0  # count what was actually evaluated (debug_break may stop early)
    nb_class = None    # taken from the model output; stays None for an empty dataset
    true_labels = []   # always collected so precision/recall works without a dump file
    pred_labels = []
    infos = []
    logging.info("Inference Model...")
    stime_all = time.time()
    for step, batch in enumerate(tqdm(infer_dataloader, unit="batch", ncols=100, desc="Inference process: ")):
        texts = batch["texts"]
        tokens = batch["tokens"].cuda(master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            probs, layer_idxes, uncertain_infos = model(tokens, token_type_ids=segment_ids, attention_mask=attn_masks,
                                                        inference=True, inference_speed=inference_speed)
        _, top_index = probs.topk(1)
        preds = top_index.view(-1)

        correct_sum += (preds == labels).sum().item()
        num_processed += preds.numel()
        nb_class = probs.shape[1]
        true_labels.extend(labels.view(-1).tolist())
        pred_labels.extend(preds.tolist())
        if dump_enabled:
            for label, pred, layer_idx, text in zip(labels, preds, layer_idxes, texts):
                infos.append((label.item(), pred.item(), layer_idx.item(), text))
        # debug_break is a module-level flag defined elsewhere in this module
        if debug_break and step > 50:
            break

    time_all = time.time() - stime_all
    # Guard against an empty dataset (previously ZeroDivisionError).
    time_per = time_all / num_processed if num_processed else 0.0
    acc = format(correct_sum / num_processed if num_processed else 0.0, "0.4f")
    logging.info("speed_arg:%s, time_per_record:%s, acc:%s, total_time:%s",
                 inference_speed, format(time_per, '0.4f'), acc, format(time_all, '0.4f'))
    if dump_enabled:
        # Explicit UTF-8: the texts may be non-ASCII and the locale default may not be UTF-8.
        with open(dump_info_file, 'w', encoding='utf-8') as fw:
            fw.writelines('\t'.join([str(label), str(pred), str(layer_i), text]) + '\n'
                          for label, pred, layer_i, text in infos)

    # Binary task: also report precision/recall, using the always-collected
    # labels rather than the dump-only ``infos``.
    if nb_class == 2:
        precise, recall = eval_pr(true_labels, pred_labels)
        logging.info("precise:%s, recall:%s", format(precise, '0.4f'), format(recall, '0.4f'))

CaoYiwei avatar Jan 06 '22 08:01 CaoYiwei



    optimizer_parameters = [
        {"params": [p for param_name, p in model.named_parameters() if not any(name in param_name for name in no_decay)], "weight_decay_rate": 0.01},
        {"params": [p for param_name, p in model.named_parameters() if any(name in param_name for name in no_decay)], "weight_decay_rate": 0.0}

CaoYiwei avatar Jan 25 '22 06:01 CaoYiwei