CoNT
Function `torch_bleu` produces incorrect BLEU scores
Hi,
While experimenting with the code, I found an error in the function `torch_bleu`,
which is used to rank batch negatives and beam positives.
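In `model.model.CoNTGenerator.torch_bleu` (lines 47-70),
there is a severe mistake that produces wrong BLEU scores: definitely when `n_gram == 1`,
and possibly when `n_gram >= 2`
(in the rare case where the token indices of two n-grams are proportional, e.g. the 2-grams [4, 8] and [34, 68]).
Cosine similarity only compares the direction of two n-gram vectors, not their values, so any two proportional n-grams score 1.0. For 1-grams, any two positive token indices are proportional, so every unigram counts as a match.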
Current code (lines 66-67):
sim_matrix = torch.cosine_similarity(input_tensor2_4gram.unsqueeze(3), input_tensor1_4gram.unsqueeze(2),
dim=-1) >= 1.0
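A small standalone reproduction of the current check, as a sketch: it assumes the n-gram tensors are built with `Tensor.unfold` and drops the `num_samples` dimension for brevity. The helper `cosine_match` and the toy token ids are hypothetical, not code from the repository.

```python
import torch


def cosine_match(hyp_ngrams: torch.Tensor, ref_ngrams: torch.Tensor) -> torch.Tensor:
    # Same test as the current code: pairwise cosine similarity >= 1.0.
    # hyp_ngrams: [batch, L_hyp, n], ref_ngrams: [batch, L_ref, n]
    # returns a bool tensor of shape [batch, L_hyp, L_ref]
    sim = torch.cosine_similarity(hyp_ngrams.unsqueeze(2), ref_ngrams.unsqueeze(1), dim=-1)
    return sim >= 1.0


ref_ids = torch.tensor([[4, 8]])    # reference token ids, shape [1, 2]
hyp_ids = torch.tensor([[34, 68]])  # hypothesis token ids, no token in common with ref

# 1-grams: unfold(dimension=1, size=1, step=1) gives shape [1, 2, 1]
uni_match = cosine_match(hyp_ids.unfold(1, 1, 1).float(), ref_ids.unfold(1, 1, 1).float())
print(uni_match.all().item())  # True: every unigram "matches" although no token is shared

# 2-grams: [34, 68] is proportional to [4, 8], so their cosine similarity is ~1.0
bi_sim = torch.cosine_similarity(
    hyp_ids.unfold(1, 2, 1).float(), ref_ids.unfold(1, 2, 1).float(), dim=-1
)
print(bi_sim)
```

Whether the 2-gram case crosses the `>= 1.0` threshold depends on floating-point rounding, which is why it only shows up "possibly"; the unigram case is exact.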
Suggestion:
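# torch.norm requires a floating-point tensor; cast with .float() if the n-grams hold integer ids
sim_matrix = torch.norm(  # pairwise L2 distance is 0 exactly when an n-gram in `sys` is present in `ref`
    input_tensor2_4gram.unsqueeze(3) - input_tensor1_4gram.unsqueeze(2),
    p=2,
    dim=-1
) == 0.0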
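A minimal sketch that checks the suggested norm test against an equivalent exact-equality test on toy data. It assumes both n-gram tensors are laid out as `[batch, num_samples, num_ngrams, n]` (the reference with `num_samples == 1`, broadcast), which is what the `unsqueeze(3)` / `unsqueeze(2)` pair above implies; the helper names and token ids are hypothetical. Because `torch.norm` expects a floating-point input, integer token ids are cast first.

```python
import torch


def match_by_norm(hyp: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Suggested fix: the L2 distance between two n-grams is 0 only if they are identical.
    diff = hyp.unsqueeze(3).float() - ref.unsqueeze(2).float()
    return torch.norm(diff, p=2, dim=-1) == 0.0


def match_by_equality(hyp: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Equivalent test with no arithmetic: all n positions must be equal.
    return (hyp.unsqueeze(3) == ref.unsqueeze(2)).all(dim=-1)


ref_ids = torch.tensor([[4, 8, 15]])                     # [batch=1, L=3]
hyp_ids = torch.tensor([[[4, 8, 15], [34, 68, 15]]])     # [batch=1, samples=2, L=3]

ref_ng = ref_ids.unfold(1, 2, 1).unsqueeze(1)  # 2-grams, [1, 1, 2, 2]
hyp_ng = hyp_ids.unfold(2, 2, 1)               # 2-grams, [1, 2, 2, 2]

match = match_by_norm(hyp_ng, ref_ng)          # [1, 2, 2, 2]: hyp n-gram x ref n-gram
# Fraction of hypothesis 2-grams found anywhere in the reference (no count clipping)
precision = match.any(dim=-1).float().mean(dim=-1)
print(precision.tolist())  # [[1.0, 0.0]]
```

`match_by_equality` gives the same result without casting or floating-point arithmetic, so it does not depend on how a computed similarity rounds near 1.0. The `precision` line only counts hypothesis n-grams that appear in the reference and omits BLEU's count clipping, so it illustrates the match test rather than reproducing `torch_bleu`.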