transformer
batch.trg[j] index out of range.
In train.py, batch.trg has size [118, 35], i.e. only 118 sentences in this batch, so looping over range(batch_size) will definitely go out of bounds whenever the actual batch (for example, the last one) is smaller than batch_size; the bare except then silently drops those sentences from the BLEU computation.
total_bleu = []
for j in range(batch_size):  # bug: batch_size can exceed batch.trg.shape[0]
    try:
        trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
        output_words = output[j].max(dim=1)[1]  # greedy choice: argmax over the vocab dimension
        output_words = idx_to_word(output_words, loader.target.vocab)
        bleu = get_bleu(hypothesis=output_words.split(), reference=trg_words.split())
        total_bleu.append(bleu)
    except:  # bare except silently swallows the resulting IndexError
        pass
So, would it be better to use for j in range(batch.trg.shape[0]) here?
I agree with you.
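
For reference, a minimal sketch of the loop with that fix applied (idx_to_word, get_bleu, loader, output, and batch are the names already used in train.py; nothing else is assumed):

total_bleu = []
# Bound the loop by the actual batch dimension instead of the configured
# batch_size, so a final, smaller batch no longer raises IndexError and
# no sentences are silently dropped from the BLEU average.
for j in range(batch.trg.shape[0]):
    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
    output_words = output[j].max(dim=1)[1]  # greedy choice: argmax over the vocab dimension
    output_words = idx_to_word(output_words, loader.target.vocab)
    bleu = get_bleu(hypothesis=output_words.split(), reference=trg_words.split())
    total_bleu.append(bleu)

With the index bounded this way, the try/except can be dropped entirely, so any remaining error surfaces instead of being masked.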