Return assistant generated tokens mask in apply_chat_template
What does this PR do?
This PR addresses issue #28950 and enhances the functionality of the tokenizer.apply_chat_template method when finetuning on chat datasets.
The method tokenizer.apply_chat_template is recommended for maintaining consistency with the model's original template during both training and inference phases. This practice ensures that conversations are processed in a uniform manner.
Moreover, during the finetuning process on chat datasets, it is crucial to exclude the tokens from the "user" or "system" segments of the conversation from the training loss. This exclusion is necessary because including these tokens would train the model to predict not only the "assistant" responses but also potential user queries, which is undesirable (and strange).
Currently, the tokenizer.apply_chat_template method does not provide a way to identify which tokens belong to the "assistant" response. To address this, the PR introduces a new parameter called return_assistant_mask. This parameter returns a mask that identifies the tokens generated by the assistant, allowing for the creation of a labels array with ignore (-100) values during training.
Additionally, this PR proposes the introduction of a new keyword, generation (name open for discussion), in the Jinja2 chat template. This keyword is used to encapsulate the assistant's response within your chat template.
Here is an example of the new API:
template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
dummy_conversation = [
    {"role": "system", "content": "system message"},
    {"role": "user", "content": "user message"},
    {"role": "assistant", "content": "assistant\nmessage"},
    {"role": "user", "content": "user message 2"},
    {"role": "assistant", "content": "assistant message 2"},
]
output = tokenizer_r.apply_chat_template(
    dummy_conversations,
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_mask=True,
    return_dict=True,
)
labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_mask"])]
There are some issues I would like to discuss during this PR:
- Is this API fine? Maybe we should return a labels key in the dict already and not bother with the intermediate mask.
- Name of the new tag? Currently generation, but maybe it should be assistant_response? Or anything you like.
- I think maybe I should add a warning if a user runs with return_assistant_mask but the tokenizer chat template hasn't been changed yet to support this new tag. That way users will know they are probably training on the wrong tokens.
- In 99% of finetuning examples I see people using the TRL trainer with packing=True. My new changes won't be easily usable if people use that parameter, and maybe we should think about my API while taking a refactor of the packing behaviour into consideration.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc @lewtun and @xenova to this as well!
My thoughts on your questions:
Is this API fine? Maybe we should return a labels key in the dict already and not bother with the intermediate mask.
I prefer just returning the labels with masking applied, rather than returning the mask for the user to apply.
Name of the new tag? Currently generation, but maybe it should be assistant_response? Or anything you like.
I think generation is fine - assistant_response is very long!
I think maybe I should add a warning if a user runs with return_assistant_mask but the tokenizer chat template hasn't been changed yet to support this new tag. That way users will know they are probably training on the wrong tokens.
Agreed! I guess the easiest way to check this is to just do a string search for {% generation %} tags? Be careful, because you'll also need to check for whitespace-control variants like {%- generation %}.
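As an illustration only (the helper name and regex below are my own suggestion, not code from this PR), such a check could tolerate the whitespace-control variants with a small regex:
import re

# Hypothetical helper: detect whether a chat template contains a
# {% generation %} block, including variants such as {%- generation -%}
# or {%generation%}.
GENERATION_TAG_RE = re.compile(r"\{%-?\s*generation\s*-?%\}")

def template_has_generation_tag(chat_template: str) -> bool:
    return GENERATION_TAG_RE.search(chat_template) is not None
A warning could then be emitted when the mask is requested but template_has_generation_tag returns False.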
In 99% of finetuning examples I see people using the TRL trainer with packing=True. My new changes won't be easily usable if people use that parameter, and maybe we should think about my API while taking a refactor of the packing behaviour into consideration.
Yes, there's already a DataCollatorForCompletionOnlyLM which also requires packing=False. I feel like we can slot in with that easily enough!
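As a rough illustration of how that could slot in (only a sketch under the assumptions of the API discussed above; the DataCollatorForAssistantOnlyLM name is hypothetical and this is not code from this PR or from TRL):
import torch

# Hypothetical collator sketch: it consumes the assistant tokens mask produced
# by apply_chat_template and builds labels where everything that was not
# generated by the assistant is set to -100.
class DataCollatorForAssistantOnlyLM:
    def __init__(self, pad_token_id, label_pad_token_id=-100):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, features):
        max_len = max(len(f["input_ids"]) for f in features)
        batch_input_ids, batch_labels, batch_attention = [], [], []
        for f in features:
            pad = max_len - len(f["input_ids"])
            input_ids = f["input_ids"] + [self.pad_token_id] * pad
            # Keep only assistant tokens as training targets.
            labels = [
                tok if mask == 1 else self.label_pad_token_id
                for tok, mask in zip(f["input_ids"], f["assistant_masks"])
            ] + [self.label_pad_token_id] * pad
            attention = [1] * len(f["input_ids"]) + [0] * pad
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_attention.append(attention)
        return {
            "input_ids": torch.tensor(batch_input_ids),
            "labels": torch.tensor(batch_labels),
            "attention_mask": torch.tensor(batch_attention),
        }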
I want to hear from @xenova and ideally someone using minijinja as well, though - how easily can we support this extension? Since it's only useful in training, maybe it's less critical to have it in huggingface/jinja or TGI, but at the very least we should be able to gracefully ignore the generation tags.
I prefer just returning the labels with masking applied, rather than returning the mask for the user to apply.
I agree, but then what should be the ignore label? -100 (PyTorch)? I'm not sure it's a good idea to add another parameter ignore_label.
I think -100 is correct, yes! This is the standard value for Torch and Transformers, so we don't need an extra arg to change it.
Yeah, I just thought of non-PyTorch users, where -100 is not the default.
Anyway, I updated the code to return labels.
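For reference, a tiny sketch of why -100 is a natural choice in PyTorch: torch's cross-entropy loss ignores positions labelled with its default ignore_index of -100, so masked user/system tokens contribute nothing to the loss.
import torch
import torch.nn.functional as F

# Toy example: logits for 4 positions over a vocabulary of 10 tokens.
logits = torch.randn(4, 10)
# The first two positions (e.g. user/system tokens) are masked with -100
# and are ignored by the loss; only the last two positions contribute.
labels = torch.tensor([-100, -100, 3, 7])
loss = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
print(loss)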
@yonigottesman Thanks for working on this; this is a feature I am very much looking forward to. Hope this can be merged soon.
Yes, sorry for not checking in @yonigottesman! Do you have other features you want to add, or should we treat this as review-ready?
@Rocketknight1 this is ready to be reviewed yes :)
On it!
@yonigottesman while I'm reviewing, can you rebase/resolve the merge conflict? It's nothing major, but it'll block us from merging the PR until it's resolved. (Edit: Probably better to rebase, because your branch is a little out of date by now; a rebase will catch any other issues before merging.)
I agree it should be assistant_mask and not labels. I feel like the collator should be added here and not in TRL. What do you think?
Yes, agree! It's also fine to leave that for a separate PR, and just add the mask functionality in this PR.
OK, fixed to now return the mask.
Got it! Ping me whenever you're ready for re-review.
Ready!
This is amazing!! Looking forward to this new change!
Fixed your suggestions. Do you think the docstring should contain a small example? For example, this is the Phi template with the new tag:
"{{ bos_token }}"
"{% for message in messages %}"
"{% if (message['role'] == 'user') %}"
"{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"
"{% elif (message['role'] == 'assistant') %}"
"{%generation%}"
"{{message['content'] + '<|end|>' + '\n'}}"
"{%endgeneration%}"
"{% endif %}"
"{% endfor %}"
Yes! We could also add a section to the end of the chat template documentation, under "Template Writing Tips" - it's a good way to make this feature visible to people. The source for it is en/chat_templating.md
I've updated that doc recently, though - make sure you rebase before editing it!
@yonigottesman Could you rebase on main to include upstream changes? This should resolve a lot of the current CI failures
@amyeroberts done
Thanks @amyeroberts. You are right, the whole sharing of the two lists between the AssistantTracker and the code outside is fishy and dangerous. I came up with a more Pythonic solution using context managers; tell me what you think.
This way the lists are only created and managed outside the object, and you need to "activate" the object with the lists to get it to work. No encapsulation is broken by code touching inner object members, and it is reasonably safe as long as we use it within a with statement.
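Roughly, the idea is something like the sketch below (simplified, and the details may differ from the actual code in the PR):
from contextlib import contextmanager

# Simplified sketch of the context-manager approach: the caller owns the two
# lists, and the tracker can only write into them while it is explicitly
# activated inside a with-block.
class AssistantTracker:
    def __init__(self):
        self._rendered_blocks = None
        self._generation_indices = None

    @contextmanager
    def activate_tracker(self, rendered_blocks, generation_indices):
        self._rendered_blocks = rendered_blocks
        self._generation_indices = generation_indices
        try:
            yield
        finally:
            # Deactivate so nothing outside the with-block can touch the
            # lists through the tracker.
            self._rendered_blocks = None
            self._generation_indices = None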
@amyeroberts anything else I should update?
@yonigottesman There's been an update on main which should fix the hub tests. Could you try rebasing? This should hopefully resolve them.
Amazing contribution! It helps me a lot!
Some spelling mistakes in the example in the PR description:
output = tokenizer_r.apply_chat_template(
    dummy_conversations,
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_mask=True,
    return_dict=True,
)
dummy_conversation instead of dummy_conversations, template instead of dummy_template
labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_mask"])]
assistant_masks instead of assistant_mask
Thank you for your work on this!
I'm having some issues though. When I run the example script from the tests, I don't seem to get any assistant tokens:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
conversations = [
    [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
        {"role": "user", "content": "user message 2"},
        {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
    ],
    [
        {"role": "system", "content": "system message 3"},
        {"role": "user", "content": "user message 3"},
        {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
        {"role": "user", "content": "user message 4"},
        {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
    ],
]
output = tokenizer.apply_chat_template(
    conversations[0],
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)
print("".join(map(str, output["assistant_masks"])))
For me, this prints out 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
I think this bug is being caused by other tokens being printed out before the {% generation %}, within the same turn. For example, if I change the chat template to:
dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{% generation %}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
it works correctly, printing out 0000000000000000000000000000000011111111111111111111111110000000000000000001111111111111111111111111
I am running the latest version of transformers, 4.44.0
@avicooper1 there is a bug, but it's not about tokens before "generation" in the same turn. If you try a different tokenizer it will work.
There is something strange about the Llama 3 tokenizer (PreTrainedTokenizerFast): for some reason the char_to_token function isn't working as expected, and my implementation is based on its result.
I opened an issue: huggingface/tokenizers#1620.
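For context, this is roughly the kind of mapping the implementation relies on (a minimal sketch, not the actual code from the PR; the checkpoint is just an arbitrary example of a fast tokenizer): the character span of each {% generation %} block in the rendered chat is recorded, and char_to_token maps that span back to token indices.
from transformers import AutoTokenizer

# Any fast tokenizer works for this illustration.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct")

text = "<|im_start|>assistant\nhello there<|im_end|>\n"
encoding = tokenizer(text)

# Suppose characters 22..33 ("hello there") were rendered inside a
# {% generation %} block; map that character span to token indices.
start_char, end_char = 22, 33
start_token = encoding.char_to_token(start_char)
end_token = encoding.char_to_token(end_char - 1)
print(start_token, end_token)
If char_to_token returns unexpected values for a given tokenizer configuration, the resulting mask will be wrong, which is what the linked issue is about.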
Given the issue, is there a workaround to get the assistant mask?
Sadly, for Llama 3 I don't think so :(
Other models that use the same tokenizer class, PreTrainedTokenizerFast (but with a different config), do work, for example tiiuae/falcon-mamba-7b-instruct, so I guess it's something specific to the Llama 3 configuration.
Unfortunately, the fact that the template needs to contain the {% generation %} part makes it very inflexible to use. Would it somehow be possible to just generate the mask based on the provided user/assistant inputs?