
Question about evaluation implementation details


Hi, thanks a lot for your great work! I recently ran into the same question as #114, but I still don't understand the logic behind your reply. I quote the original question:

In your easyeditor/evaluate/evaluate_utils.py, when you test the "prediction" acc, why do you feed in prompt_target_tok, which already contains the expected prediction target? Shouldn't prompt_tok be used instead, so that we get the model's real output? I am a beginner and this confuses me. I would really appreciate an answer, thank you!

import numpy as np
import torch

# slice_list is a helper defined in the same evaluate_utils.py file.
def test_prediction_acc(model, tok, hparams, prompts, targets, device, locality=False):
    if isinstance(prompts, str):
        prompts, targets = [prompts,], [targets,]
    # Concatenate each prompt with its target so one forward pass scores every target token.
    prompt_target = [prompt + ' ' + target for prompt, target in zip(prompts, targets)]
    max_prompt_len = max([len(tok.encode(_)) for _ in prompt_target]) + 1
    prompt_target_tok = tok(
        prompt_target,
        padding=True,
        truncation=True,
        max_length=max(hparams.max_length, max_prompt_len),
        return_tensors="pt",
    ).to(f"cuda:{device}")
    prompt_tok = tok(
        prompts,
        padding=True,
        truncation=True,
        max_length=max(hparams.max_length, max_prompt_len),
        return_tensors="pt",
    )
    # Count the prompt tokens and the padding in prompt_target_tok; together they give the
    # offset used below to slice out the target span.
    num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in prompt_tok['input_ids']]
    num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in prompt_target_tok['input_ids'].cpu()]
    prompt_len = [x + y for x, y in zip(num_pad_toks, num_prompt_toks)]
    with torch.no_grad():
        outputs = model(**prompt_target_tok)
        if type(outputs) is torch.Tensor:
            logits = outputs
        else:
            logits = outputs.logits
        # Greedy prediction at every position, and the actual input ids as labels.
        answers = torch.argmax(logits, dim=-1).squeeze().detach().cpu().numpy().tolist()
        labels = prompt_target_tok['input_ids'].squeeze().detach().cpu().numpy().tolist()
        # Keep only the target span. left=True shifts the prediction slice one position
        # earlier, since the logits at position i predict the token at position i + 1.
        answers = slice_list(answers, prompt_len, left=True)
        labels = slice_list(labels, prompt_len, left=False)
        if locality:
            return answers if type(answers[0]) is list else [answers,]
        if isinstance(answers[0], list):
            res = []
            for ans, label in zip(answers, labels):
                temp_acc = np.mean(np.equal(ans, label))
                if np.isnan(temp_acc):
                    continue
                res.append(temp_acc)
            return res
        else:
            # Token-level accuracy of the target span under teacher forcing.
            return [np.mean(np.equal(answers, labels))]

Thanks for your attention. You can refer to how a decoder-only model runs inference, which is essentially next-token prediction (e.g., for "I want to have", the model predicts "lunch" at the position of "have"). So for the expected prediction target, we execute:

input_prompt ---(predict)---> target_token_1
input_prompt + target_token_1 ---(predict)---> target_token_2
input_prompt + target_token_1 + target_token_2 ---(predict)---> target_token_3
...

Originally posted by @pengzju in https://github.com/zjunlp/EasyEdit/issues/114#issuecomment-1854998369

Could you please elaborate a little more? Thanks again.

asdfo123 (Sep 29, 2024)

In essence, this evaluation method concatenates the question (Q) and answer (A) and feeds the result into the language model for a single inference pass. Since a decoder-only model's attention is causal, each position is essentially predicting the next token. Thus, the metric being tested is how reliably the edited knowledge is output through next-token prediction. This can be represented as follows:

input_prompt ---(LM predict)---> target_token_1
input_prompt + target_token_1 ---(LM predict)---> target_token_2
input_prompt + target_token_1 + target_token_2 ---(LM predict)---> target_token_3
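
For concreteness, here is a minimal, self-contained sketch of the same idea against a generic Hugging Face causal LM. The model name, prompt, and target below are illustrative placeholders rather than EasyEdit's actual setup; the point is only to show that a single forward pass over prompt + target scores every target token at once.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any decoder-only LM behaves the same way
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

prompt, target = "The capital of France is", " Paris"
prompt_ids = tok(prompt, return_tensors="pt").input_ids
full_ids = tok(prompt + target, return_tensors="pt").input_ids  # prompt and target concatenated

with torch.no_grad():
    logits = model(full_ids).logits  # shape: [1, seq_len, vocab_size]

# The logits at position i are the model's prediction for the token at position i + 1,
# so the predictions for the target span start one position before the target does.
prompt_len = prompt_ids.shape[1]
pred_ids = logits.argmax(dim=-1)[:, prompt_len - 1 : -1]  # predicted target tokens
label_ids = full_ids[:, prompt_len:]                      # actual target tokens

acc = (pred_ids == label_ids).float().mean().item()
print(acc)

This is the chain above collapsed into one step: because attention is causal, every "input_prompt + first k target tokens ---> target_token_{k+1}" prediction is already available in the single forward pass, so no generation loop is needed.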

pengzju (Sep 29, 2024)

hi, do you have any further issues?

zxlzr (Oct 7, 2024)


No, I understand it now after learning about the relationship between the logits and tok['input_ids']. Thanks for your reply!

asdfo123 (Oct 11, 2024)