pytorch-seq2seq
A question about tutorial1.
I added a function (following the code in tutorial 4) to calculate the BLEU score, but I get a very low score (0.09). Could you tell me why? This is the code I use to calculate BLEU:
import torch
import spacy
from torchtext.data.metrics import bleu_score


def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50):
    model.eval()

    if isinstance(sentence, str):
        nlp = spacy.load('de')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)

    with torch.no_grad():
        hidden, cell = model.encoder(src_tensor)

    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]

    for i in range(max_len):
        # feed the most recently predicted token back in as the next input
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)

        with torch.no_grad():
            output, _, _ = model.decoder(trg_tensor, hidden, cell)

        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break

    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]

    return trg_tokens[1:]


def calculate_bleu(data, src_field, trg_field, model, device, max_len=50):
    trgs = []
    pred_trgs = []

    for datum in data:
        src = vars(datum)['src']
        trg = vars(datum)['trg']

        pred_trg = translate_sentence(src, src_field, trg_field, model, device, max_len)

        # cut off <eos> token
        pred_trg = pred_trg[:-1]

        pred_trgs.append(pred_trg)
        trgs.append([trg])

    return bleu_score(pred_trgs, trgs)
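For reference, the function is invoked the same way as in tutorial 4; a minimal usage sketch, assuming the usual objects built in the tutorial notebook (test_data, SRC, TRG, model, device):

    # test_data, SRC, TRG, model and device come from the tutorial notebook
    score = calculate_bleu(test_data, SRC, TRG, model, device)
    print(f'BLEU score = {score*100:.2f}')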
The tutorial 1 model is supposed to be the "worst" of all the sequence-to-sequence models implemented in these tutorials, which is why it has a low BLEU score.
I'd have thought it would be a bit higher than 0.09 though - I'll look into it.
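One thing that may explain the gap, offered as a guess rather than a confirmed diagnosis: the decoding loop above discards the decoder's updated hidden and cell states (output, _, _ = model.decoder(...)), so every step decodes from the encoder's final state instead of carrying the state forward the way the tutorial 1 model does during training. A sketch of the loop with the state fed back in, assuming the decoder returns (prediction, hidden, cell) as in tutorial 1:

    for i in range(max_len):
        trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device)

        with torch.no_grad():
            # keep the updated hidden and cell states so the next step
            # conditions on everything generated so far
            output, hidden, cell = model.decoder(trg_tensor, hidden, cell)

        pred_token = output.argmax(1).item()
        trg_indexes.append(pred_token)

        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break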