Bert-Chinese-Text-Classification-Pytorch
A question about the prediction code
The prediction code is as follows:
```python
import os

import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertModel, BertTokenizer

# Categories to recognize (label id -> label name)
key = {0: '别名',
       1: '防治农药',
       2: '病原学名',
       3: '病原中文名',
       4: '病原属性',
       5: '为害部位',
       6: '为害作物',
       7: '属目',
       8: '属科',
       9: '学名'
       }


class Config:
    """Configuration parameters"""
    def __init__(self):
        cru = os.path.dirname(__file__)
        self.class_list = [str(i) for i in range(len(key))]  # list of class names
        self.save_path = os.path.join(cru, 'ernie/ERNIE.ckpt')
        self.device = torch.device('cpu')
        self.require_improvement = 1000  # stop training early if no improvement after 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded/truncated to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = os.path.join(cru, 'bert')
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
    def build_dataset(self, text):
        """Turn one sentence into (token_ids, mask) tensors of length pad_size."""
        lin = text.strip()
        pad_size = self.pad_size  # use the training-time length, not len(lin)
        token = ['[CLS]'] + self.tokenizer.tokenize(lin)
        token_ids = self.tokenizer.convert_tokens_to_ids(token)
        # Pad short sequences with 0 and truncate long ones so that
        # token_ids and mask always have exactly pad_size entries.
        if len(token_ids) < pad_size:
            mask = [1] * len(token_ids) + [0] * (pad_size - len(token_ids))
            token_ids = token_ids + [0] * (pad_size - len(token_ids))
        else:
            mask = [1] * pad_size
            token_ids = token_ids[:pad_size]
        return torch.tensor([token_ids], dtype=torch.long), torch.tensor([mask])

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # input token ids
        mask = x[1]     # attention mask
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)
        return out


config = Config()
model = Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))
model.eval()  # disable dropout so inference is deterministic


def prediction_model(text):
    """Predict the category of a single input question."""
    data = config.build_dataset(text)
    with torch.no_grad():
        outputs = model(data)
    num = torch.argmax(outputs)
    return key[int(num)]


if __name__ == '__main__':
    print(prediction_model('水稻恶苗病主要危害哪些部位?'))
```
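For context, the checkpoint loaded above is presumably written during training by saving the model's full state_dict. A minimal sketch, assuming the training script saves the whole model rather than just the classifier head:

```python
# Sketch, assuming training ends with saving the complete model state.
# state_dict() walks every registered submodule, so both the fine-tuned
# 'bert.*' tensors and the 'fc.*' tensors end up in the checkpoint file.
torch.save(model.state_dict(), config.save_path)
```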
Suppose we use a bert + fc network structure. When we call model = Model(config).to(config.device), BertModel.from_pretrained loads BERT's original pretrained weights; model.load_state_dict(torch.load(config.save_path, map_location='cpu')) then loads the parameters trained on our own data. Because BERT's parameters have gradients enabled during training (they are trainable), the trained model we save should contain the fully connected layer's parameters as well as BERT's fine-tuned parameters. Why, then, do we still need to load BERT's original weights? Done this way, deployment requires both the original BERT model files and the fine-tuned model file to be present at the same time, which I don't really understand.
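If the checkpoint really does contain the fine-tuned BERT weights, then from_pretrained here only supplies an initialization that load_state_dict immediately overwrites. Below is a minimal sketch of how one might verify this, and how the model could be built from bert_config.json alone so the original pretrained weight file is not needed at deployment (assumptions: the checkpoint was saved via model.state_dict(), bert_config.json and vocab.txt exist under config.bert_path, and ModelFromConfig is a hypothetical variant, not part of this repo):

```python
import os
import torch
import torch.nn as nn
from pytorch_pretrained_bert import BertConfig, BertModel

# 1) Inspect the checkpoint: both the fine-tuned BERT weights ('bert.*')
#    and the classifier head ('fc.*') should already be inside it.
state_dict = torch.load(config.save_path, map_location='cpu')
print(sorted(k for k in state_dict if k.startswith('fc.')))  # ['fc.bias', 'fc.weight']
print(any(k.startswith('bert.') for k in state_dict))        # True -> BERT weights included

# 2) Hypothetical variant that builds the network skeleton from
#    bert_config.json only (randomly initialized); load_state_dict then
#    overwrites every parameter with the fine-tuned values.
class ModelFromConfig(nn.Module):
    def __init__(self, config):
        super(ModelFromConfig, self).__init__()
        bert_config = BertConfig.from_json_file(
            os.path.join(config.bert_path, 'bert_config.json'))
        self.bert = BertModel(bert_config)  # random init, no pretrained .bin needed
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context, mask = x
        _, pooled = self.bert(context, attention_mask=mask,
                              output_all_encoded_layers=False)
        return self.fc(pooled)

m = ModelFromConfig(config).to(config.device)
m.load_state_dict(torch.load(config.save_path, map_location='cpu'))
m.eval()
```

Note that vocab.txt is still required because BertTokenizer.from_pretrained reads it; only the large pretrained weight file becomes redundant once the fine-tuned checkpoint exists.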
Same here; I don't understand this either.