
train_on_inputs clarification

Open ElleLeonne opened this issue 1 year ago • 1 comment

When you say train_on_inputs = False, I presume you mean to mask out the prompt, and train the loss only on the response that the model is supposed to produce.

This is made slightly confusing by the fact that the prompt itself has an "input" field. The naming would seem to imply that you're only masking out the portion of the JSON with the key "input". This should be changed for clarity.

I propose that "train_on_inputs" be renamed to "train_on_prompt".
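
To make the ambiguity concrete, here is a sketch of a typical alpaca-style record (field names as in the dataset, values invented for illustration):

data_point = {
    "instruction": "Translate the input to French.",
    "input": "hello",
    "output": "bonjour",
}

With train_on_inputs = False, it looks like the whole rendered prompt (instruction and input together) is masked out of the loss, not just the "input" field.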

ElleLeonne avatar Apr 02 '23 19:04 ElleLeonne

Here is the relevant code. I'm not 100% sure about this Python syntax, but it looks like it sets the labels to -100 for the tokenized prompt across the entire length of instruction + input, and leaves the rest (the output) intact.

In other words, it keeps the output but excludes the instruction and input from the loss via the mask (-100 is the ignore index for the loss, so those positions contribute nothing to training).

def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        user_prompt = prompter.generate_prompt(
            data_point["instruction"], data_point["input"]
        )
        tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
        user_prompt_len = len(tokenized_user_prompt["input_ids"])

        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt
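
If it helps, here is a minimal, self-contained sketch of the same arithmetic using a made-up whitespace "tokenizer" in place of the real tokenize/prompter, just to show what the labels look like when train_on_inputs is False:

def toy_tokenize(text):
    # Stand-in for tokenize(): one fake token id per whitespace-separated word.
    ids = list(range(len(text.split())))
    return {"input_ids": ids, "labels": list(ids)}

def mask_prompt_labels(full_prompt, user_prompt):
    tokenized_full = toy_tokenize(full_prompt)
    user_prompt_len = len(toy_tokenize(user_prompt)["input_ids"])
    # Same trick as above: -100 for every prompt position, real labels for the response.
    tokenized_full["labels"] = [-100] * user_prompt_len + tokenized_full["labels"][user_prompt_len:]
    return tokenized_full

full = "Instruction: Translate to French. Input: hello Response: bonjour"
user = "Instruction: Translate to French. Input: hello Response:"
print(mask_prompt_labels(full, user)["labels"])
# -> [-100, -100, -100, -100, -100, -100, -100, 7]

Only the response token keeps a real label; everything belonging to the prompt is -100 and is ignored by the loss.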

jquave avatar Apr 11 '23 21:04 jquave