Bert-Chinese-Text-Classification-Pytorch

Prediction code

kalayan opened this issue 3 years ago · 4 comments

```python
import torch
from importlib import import_module

dataset = 'THUCNews'   # dataset folder
model_name = 'bert'
x = import_module('models.' + model_name)
config = x.Config(dataset)

model = x.Model(config).to(config.device)
model.load_state_dict(torch.load('model67'))


class DatasetIterater(object):
    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size
        self.residue = False  # whether there is a leftover (partial) batch
        # guard against division by zero when there are fewer samples than one full batch
        if self.n_batches == 0 or len(batches) % self.n_batches != 0:
            self.residue = True
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        # sequence length before padding (capped at pad_size)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches
        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches


def prediction_model(content):
    pad_size = 32
    PAD, CLS = '[PAD]', '[CLS]'  # padding token and BERT's sentence-level [CLS] token
    label = 0                    # placeholder label; only the prediction is used
    token = config.tokenizer.tokenize(content)
    token = [CLS] + token
    seq_len = len(token)
    mask = []
    token_ids = config.tokenizer.convert_tokens_to_ids(token)
    if len(token) < pad_size:
        mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
        token_ids += [0] * (pad_size - len(token))
    else:
        mask = [1] * pad_size
        token_ids = token_ids[:pad_size]
        seq_len = pad_size

    msg_tuple = [(token_ids, int(label), seq_len, mask)]
    test_iter = DatasetIterater(msg_tuple, config.batch_size, config.device)
    with torch.no_grad():
        for texts, labels in test_iter:
            outputs = model(texts)
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()

    print(predic[0])
    return predic[0]


prediction_model('这是要分类的文字')
```

kalayan · Mar 15 '21
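A hedged follow-up to the snippet above: `prediction_model` returns a class index, and the repo's `Config` normally fills `config.class_list` from `THUCNews/data/class.txt`, so the index can be mapped back to a label name. The `predict_label` helper below is hypothetical, a minimal sketch under that assumption:

```python
# Hypothetical helper, assuming config.class_list was loaded from
# THUCNews/data/class.txt by the repo's Config (as in models/bert.py).
def predict_label(content):
    idx = int(prediction_model(content))  # class index from the snippet above
    return config.class_list[idx]         # the corresponding label string

print(predict_label('这是要分类的文字'))
```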

A question: if I want to load two models one after the other, how should I reset torch after loading the first one so that loading the second model is not affected by the first?

shengtaovvv · Mar 19 '21
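Not an authoritative answer, but a common pattern: there is nothing global to reset in torch. Keep each model in its own `nn.Module` instance, load its own `state_dict`, and release the first model's references (and cached GPU memory) before or after building the second. A minimal sketch, reusing the `x`/`config` setup from the snippet above with two hypothetical checkpoint paths:

```python
import gc
import torch

# First model: build, load weights, use, then release explicitly.
model_a = x.Model(config).to(config.device)
model_a.load_state_dict(torch.load('checkpoint_a.ckpt', map_location=config.device))
model_a.eval()
# ... run predictions with model_a ...

del model_a                  # drop the Python reference
gc.collect()                 # reclaim the object
torch.cuda.empty_cache()     # return cached GPU memory to the driver (if on CUDA)

# Second model: a fresh Module instance with its own weights, so it is not
# influenced by the first one.
model_b = x.Model(config).to(config.device)
model_b.load_state_dict(torch.load('checkpoint_b.ckpt', map_location=config.device))
model_b.eval()
```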

In `model.load_state_dict(torch.load('model67'))`, which part of this project should `model67` be replaced with?

vencentDebug · Apr 16 '21
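Hedged note, based on the repo's default training script: `train_eval.py` saves the best weights to `config.save_path`, which the bundled `Config` builds as `<dataset>/saved_dict/<model_name>.ckpt` (here `THUCNews/saved_dict/bert.ckpt`), so that path would normally stand in for `'model67'`:

```python
# Assuming training was run with the repo defaults, load the checkpoint that
# train_eval.py wrote to config.save_path instead of a hand-named 'model67'.
model = x.Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location=config.device))
```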

Why does this code give a different prediction result each time it is run?

Drizzlenum · Apr 13 '23
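One likely cause, hedged: the snippet at the top of the thread never calls `model.eval()`, so BERT's dropout layers stay active at inference time and the logits change on every forward pass. Switching to evaluation mode after loading the weights should make repeated predictions on the same text agree, e.g.:

```python
# Same setup as in the snippet above; the only change is the model.eval() call.
model = x.Model(config).to(config.device)
model.load_state_dict(torch.load('model67'))
model.eval()  # disable dropout so a fixed input always yields the same logits

# prediction_model(...) can then be called as before and will return a stable
# class index for the same input text.
```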