Chinese-Text-Classification-Pytorch
请问能否提供一个预测类，用训练好的模型对单条文本进行分类预测？
我自己写了一段预测代码：
def _encode_text(text, vocab, word_level, max_len=32):
    """Turn one raw string into a fixed-length list of vocabulary ids.

    Args:
        text: the raw input string to classify.
        vocab: mapping from token to integer id. Tokens that are not in
            it are dropped; no UNK id is substituted.
        word_level: if truthy, split ``text`` on single spaces; otherwise
            every character is a token.
        max_len: output length. Longer id lists keep only their last
            ``max_len`` ids; shorter ones are left-padded with 0. These are
            the Keras ``pad_sequences`` defaults (padding='pre',
            truncating='pre', value=0), reproduced here so a PyTorch-only
            environment does not need Keras installed.

    Returns:
        A list of exactly ``max_len`` ints (empty when ``max_len`` <= 0).
    """
    tokens = text.split(' ') if word_level else list(text)
    ids = [vocab[tok] for tok in tokens if tok in vocab]
    if max_len <= 0:
        return []
    ids = ids[-max_len:]
    return [0] * (max_len - len(ids)) + ids


# Predict the class of a single hard-coded sample string.
if args.status == 'predict':
    line = "union select or 1=1"
    vocab = build_vocab(train_path, args.word, args.vocab_size, args.min_freq)
    n_vocab = len(vocab)
    model = Model(n_vocab, args.num_classes, args.embed, args.num_filters,
                  args.filter_sizes, args.dropout).to(device)
    # TODO(review): ``model`` is freshly initialised, so its output is
    # meaningless until trained weights are loaded, e.g.
    #     model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    # The checkpoint path is not visible in this snippet -- fill it in.
    # NOTE(review): the vocabulary is rebuilt from ``train_path``; it must
    # yield the same token->id mapping the model was trained with.
    # ``model.parameters`` without a call is just a bound method; print the
    # module itself to see its structure.
    print(model)
    # Switch off dropout (and any other train-only behaviour) so the
    # prediction is deterministic.
    model.eval()

    ids = _encode_text(line, vocab, args.word, max_len=32)
    # A batch of exactly one sequence, shape (1, max_len), on the same device
    # as the model.
    # NOTE(review): assumes Model.forward takes a (batch, seq_len) LongTensor
    # -- confirm against the Model definition.
    test_batch = torch.tensor([ids], dtype=torch.long, device=device)

    # A single forward pass without autograd bookkeeping (``Variable`` is
    # deprecated and unnecessary for inference).
    with torch.no_grad():
        outputs = model(test_batch)
    # Index of the highest score along dim 1, i.e. the predicted label.
    label = outputs.argmax(dim=1).item()
    print(label)
我自己写了一段预测代码：
def _encode_text(text, vocab, word_level, max_len=32):
    """Turn one raw string into a fixed-length list of vocabulary ids.

    Args:
        text: the raw input string to classify.
        vocab: mapping from token to integer id. Tokens that are not in
            it are dropped; no UNK id is substituted.
        word_level: if truthy, split ``text`` on single spaces; otherwise
            every character is a token.
        max_len: output length. Longer id lists keep only their last
            ``max_len`` ids; shorter ones are left-padded with 0. These are
            the Keras ``pad_sequences`` defaults (padding='pre',
            truncating='pre', value=0), reproduced here so a PyTorch-only
            environment does not need Keras installed.

    Returns:
        A list of exactly ``max_len`` ints (empty when ``max_len`` <= 0).
    """
    tokens = text.split(' ') if word_level else list(text)
    ids = [vocab[tok] for tok in tokens if tok in vocab]
    if max_len <= 0:
        return []
    ids = ids[-max_len:]
    return [0] * (max_len - len(ids)) + ids


# Predict the class of a single hard-coded sample string.
if args.status == 'predict':
    line = "union select or 1=1"
    vocab = build_vocab(train_path, args.word, args.vocab_size, args.min_freq)
    n_vocab = len(vocab)
    model = Model(n_vocab, args.num_classes, args.embed, args.num_filters,
                  args.filter_sizes, args.dropout).to(device)
    # TODO(review): ``model`` is freshly initialised, so its output is
    # meaningless until trained weights are loaded, e.g.
    #     model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    # The checkpoint path is not visible in this snippet -- fill it in.
    # NOTE(review): the vocabulary is rebuilt from ``train_path``; it must
    # yield the same token->id mapping the model was trained with.
    # ``model.parameters`` without a call is just a bound method; print the
    # module itself to see its structure.
    print(model)
    # Switch off dropout (and any other train-only behaviour) so the
    # prediction is deterministic.
    model.eval()

    ids = _encode_text(line, vocab, args.word, max_len=32)
    # A batch of exactly one sequence, shape (1, max_len), on the same device
    # as the model.
    # NOTE(review): assumes Model.forward takes a (batch, seq_len) LongTensor
    # -- confirm against the Model definition.
    test_batch = torch.tensor([ids], dtype=torch.long, device=device)

    # A single forward pass without autograd bookkeeping (``Variable`` is
    # deprecated and unnecessary for inference).
    with torch.no_grad():
        outputs = model(test_batch)
    # Index of the highest score along dim 1, i.e. the predicted label.
    label = outputs.argmax(dim=1).item()
    print(label)
但这段代码在本项目中无法直接使用。