Bert-Chinese-Text-Classification-Pytorch
Bert-Chinese-Text-Classification-Pytorch copied to clipboard
感谢分享。然而菜鸡如我,写不出一个单文本的predict接口
感谢楼主分享,训练好了模型,从test效果感觉还行。只是写不出predict脚本。有哪位大佬写过求指教
def predict(config,model,data_iter):
model.load_state_dict(torch.load(config.save_path))
model.eval()
predict_all = np.array([],dtype=int)
with torch.no_grad():
for texts,labels in data_iter:
outputs = model(texts)
predic = torch.max(outputs.data,1)[1].cpu().numpy()
predict_all = np.append(predict_all,predic)
return predict_all
#70 预测代码