DoLa
The way to calculate the log_probs
Hi, since the model's output at each token position represents the probability distribution over the next token, shouldn't the log_probs calculation be shifted by one position? That is, "log_probs = diff_logits[range(diff_logits.shape[0]-1), continue_ids[1:]].sum().item()" instead of "log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()".
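To make the alignment question concrete, here is a minimal sketch in NumPy, which stands in for the PyTorch tensors. The names `continuation_logprob`, `prefix_len`, and the toy logits are hypothetical and do not come from the DoLa code. The sketch assumes the standard causal-LM convention that row t of the logits scores the token at position t + 1, so the rows used to score the continuation start at `prefix_len - 1`.

```python
import numpy as np

def log_softmax(x):
    # Numerically stable log-softmax over the vocabulary (last) axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def continuation_logprob(logits, input_ids, prefix_len, shift=True):
    """Sum log p(x_t | x_<t) over the continuation tokens.

    logits: (seq_len, vocab) array for the full sequence, where row t is
    assumed to be the distribution over the token at position t + 1.
    """
    log_probs = log_softmax(logits)
    cont_ids = input_ids[prefix_len:]
    if shift:
        # Row prefix_len - 1 predicts the first continuation token.
        rows = np.arange(prefix_len - 1, len(input_ids) - 1)
    else:
        # Unshifted: row t paired with token t (off by one position).
        rows = np.arange(prefix_len, len(input_ids))
    return log_probs[rows, cont_ids].sum()

# Hypothetical toy "model": row t puts all its mass on input_ids[t + 1].
input_ids = np.array([1, 3, 2, 5, 4])
prefix_len = 2
logits = np.full((len(input_ids), 6), -1e9)
for t in range(len(input_ids) - 1):
    logits[t, input_ids[t + 1]] = 0.0
logits[-1] = 0.0  # last row predicts past the end; uniform here

aligned = continuation_logprob(logits, input_ids, prefix_len, shift=True)
misaligned = continuation_logprob(logits, input_ids, prefix_len, shift=False)
print(aligned, misaligned)
```

Under this convention, unshifted indexing like "diff_logits[range(n), continue_ids]" is aligned only if diff_logits was already sliced to start at the last prefix position. If its first row corresponds to the first continuation position, a one-step shift is needed.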