Replacing .sum() with .mean() in dola.py?
Hi Team, great work on the project!
I noticed something in dola.py line 217 that might need attention. The code is:
log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()
Should .sum() be replaced with .mean()? Since false answers and correct answers may have different numbers of tokens, using .mean() might provide a fairer comparison.
Thanks!
I have the same question and look forward to getting an answer. Thank you.
I think both implementations should be fine. In theory, we should not apply .mean(), because .sum() over the per-token log-probabilities gives the log-probability of the whole sentence: $\log P(x_1, x_2, ..., x_N) = \sum^N_{i=1} \log P(x_i | x_1, x_2, ..., x_{i-1})$. There is no .mean() operation in this equation.
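To make the chain-rule argument concrete, here is a minimal sketch that checks it numerically. The probabilities are made up for illustration, and it assumes the per-token scores are plain log-probabilities (the contrasted scores in dola.py may not be exactly that):

```python
import math

# Hypothetical per-token conditional probabilities P(x_i | x_1, ..., x_{i-1})
# for a single three-token sentence. These values are illustrative only.
token_probs = [0.5, 0.25, 0.8]

# Joint probability of the sentence: the product of the conditionals.
joint_prob = math.prod(token_probs)

# Summing per-token log-probabilities gives the log of that product.
sum_log_probs = sum(math.log(p) for p in token_probs)

# The two agree up to floating-point error.
assert math.isclose(math.log(joint_prob), sum_log_probs)
```

So .sum() over log-probabilities corresponds to the joint probability of the sentence, while .mean() divides that log-probability by the number of tokens.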
In practice, people apply .mean() because language models are not perfect and will be biased by sentence length when the candidate answers differ widely in length (both short answers and long answers).
In my experiments, I observed almost no change when switching from .sum() to .mean(). I think that's because all the answer sentences in the datasets I tested have very similar lengths, so I just kept the .sum() implementation. You can switch to .mean() if your dataset has very different answer sentence lengths.
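For anyone who wants to see the length effect in isolation, here is a rough sketch with made-up per-token log-probabilities. The helper names `score_sum` and `score_mean` are hypothetical and are not part of dola.py:

```python
# Hypothetical per-token log-probabilities for two candidate answers.
# The long answer has a higher log-probability on every token, but more tokens.
short_answer = [-1.2, -1.0]                    # 2 tokens
long_answer = [-0.6, -0.5, -0.7, -0.6, -0.5]   # 5 tokens


def score_sum(log_probs):
    """Sentence score as the sum of token log-probs (like .sum())."""
    return sum(log_probs)


def score_mean(log_probs):
    """Length-normalized score: average token log-prob (like .mean())."""
    return sum(log_probs) / len(log_probs)


# With the sum, the short answer scores higher simply because it has fewer terms.
sum_prefers_short = score_sum(short_answer) > score_sum(long_answer)

# With the mean, the long answer scores higher because each token is more likely.
mean_prefers_long = score_mean(long_answer) > score_mean(short_answer)
```

When the candidates all have similar lengths, the two scores rank them almost identically, which would be consistent with seeing almost no difference between .sum() and .mean().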