Conversational data for SFTTrainer
For SFTTrainer, if we load the dataset in conversational (ChatML) format, `apply_chat_template` is called with `tokenize=False` (https://github.com/huggingface/trl/blob/v0.7.11/trl/extras/dataset_formatting.py#L55). Later, SFTTrainer tokenizes the data again with `add_special_tokens=True`. With tokenizers like LlamaTokenizer, this leaves two BOS tokens at the very beginning (`<s><s> ...`), which is not intended.
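For concreteness, a minimal reproduction sketch (the checkpoint is only an example; any tokenizer whose chat template already renders a BOS token behaves the same way):

```python
from transformers import AutoTokenizer

# Example checkpoint only; any tokenizer whose chat template already
# renders a BOS token (Llama-style templates do) reproduces this.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

messages = [{"role": "user", "content": "Hello!"}]

# Step 1 -- what the dataset formatting step does: render the template
# to plain text. For Llama the rendered string already begins with "<s>".
text = tokenizer.apply_chat_template(messages, tokenize=False)

# Step 2 -- what SFTTrainer then does: tokenize the rendered text with
# add_special_tokens=True, which prepends a second BOS token.
ids = tokenizer(text, add_special_tokens=True)["input_ids"]

print(tokenizer.convert_ids_to_tokens(ids)[:2])  # ['<s>', '<s>']
```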
Maybe we should modify `dataset_kwargs` at this line https://github.com/huggingface/trl/blob/v0.7.11/trl/trainer/sft_trainer.py#L246 so that `dataset_kwargs['add_special_tokens'] = False`?
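Continuing the snippet above, the proposed change would stop the second tokenization pass from injecting special tokens, leaving a single BOS:

```python
# With dataset_kwargs['add_special_tokens'] = False, only the BOS that
# the chat template itself rendered remains.
ids = tokenizer(text, add_special_tokens=False)["input_ids"]

print(tokenizer.convert_ids_to_tokens(ids)[:2])  # ['<s>', ...] -- one BOS
```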
Yes, that would make sense. Would you like to open a PR for the fix? cc @philschmid, what do you think?
Sure, I will do that.
Good idea @edixiong, that's what I currently do manually: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#4-fine-tune-llm-using-trl-and-the-sfttrainer
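For reference, that manual workaround amounts to passing `dataset_kwargs` when constructing the trainer (a sketch; `model`, `tokenizer`, and `dataset` are placeholders, and the argument names follow TRL v0.7.x):

```python
from trl import SFTTrainer

# model, tokenizer, and dataset are placeholders for illustration.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    max_seq_length=2048,
    packing=True,
    dataset_kwargs={
        # The chat template already inserts the special tokens, so skip
        # adding them a second time during tokenization.
        "add_special_tokens": False,
    },
)
```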
This probably should only be applied when a chat template such as ChatML is detected.
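A hypothetical sketch of what that guard could look like inside SFTTrainer's dataset preparation (`get_formatting_func_from_dataset` is the detector in `trl/extras/dataset_formatting.py`; the surrounding variable names are assumptions based on the v0.7.11 source):

```python
from trl.extras.dataset_formatting import get_formatting_func_from_dataset

# Returns a formatting function only when a known conversational format
# (e.g. ChatML) is detected in the dataset; otherwise returns None.
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

if formatting_func is not None:
    # The detected chat template already emits the special tokens, so
    # avoid adding them again when the dataset is tokenized.
    if dataset_kwargs is None:
        dataset_kwargs = {}
    dataset_kwargs["add_special_tokens"] = False
```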
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.