Prototype Dataset Processor
This PR attempts to refactor all tokenization logic out of the Trainer classes. Having a separate tokenization step gives us higher visibility into what is actually used in training, clarifies the logic, and reduces bugs. It attempts to do the following things:
# 1. PPO (prompt)
# 2. SFT (prompt + demonstration), there is also packing.
# 3. ✅ RM / DPO (chosen and rejected)
# 4. ✅ Visualization of length distributions?
# 5. ✅ Filter?
# * Smart truncation?
# 6. ✅ dataset_num_proc
# 7. check EOS token
# 8. dataset mixer?
# 9. ✅ pretty print that shows tokenization?
# 10. hashable tokenization?
# 11. inputs / labels / attention_mask
# 12. always set a `tokenizer.pad_token_id`?
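For items 7 and 12 above, here is a minimal sketch of the kind of check the processor could run. The helper name and the EOS-as-pad fallback are assumptions for illustration, not necessarily what this PR ships:

```python
def ensure_pad_token(tokenizer):
    """Item 12: make sure `tokenizer.pad_token_id` is set.

    A common fallback (an assumption here, not a decision made by this PR)
    is to reuse the EOS token as the pad token when none is configured.
    Works with any transformers-style tokenizer object.
    """
    if getattr(tokenizer, "pad_token_id", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
```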
Why?
Currently, the Trainer is also responsible for tokenization. This causes several issues:
- Duplicate tokenization steps. For example, alignment-handbook calls `apply_chat_template(tokenize=False)` on the dataset, and the SFT/DPO trainer then tokenizes it again. To remove the duplication, we only need to go through the dataset once by calling `apply_chat_template(tokenize=True)`.
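To illustrate the double pass with toy stand-ins (the real code uses `tokenizer.apply_chat_template` from transformers; the template and tokenizer below are simplified assumptions):

```python
def render_chat(messages):
    # Toy chat template (illustration only; real templates are tokenizer-specific
    # Jinja templates applied via `apply_chat_template`).
    return " ".join(f"<|{m['role']}|> {m['content']}" for m in messages)

def toy_tokenize(text):
    # Toy whitespace "tokenizer" standing in for a real one.
    return text.split()

messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]

# Today: two passes over the data (format first, trainer tokenizes later).
formatted = render_chat(messages)        # like apply_chat_template(tokenize=False)
ids_two_pass = toy_tokenize(formatted)   # trainer tokenizes again

# Proposed: a single pass, i.e. apply_chat_template(tokenize=True).
ids_one_pass = toy_tokenize(render_chat(messages))

assert ids_one_pass == ids_two_pass
```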
- Truncation logic happens in various places and is hard to predict. SFTTrainer calls it `max_seq_length`, reward modeling calls it `max_length`, and the DPO/KTO trainers use `max_length`, `max_prompt_length`, and `max_target_length`. There are also different truncation behaviors, e.g. [truncate the prompt if prompt + chosen is too long](https://github.com/huggingface/trl/blob/99f2c94b2200927a1dc156f16e012dca11f865e1/trl/trainer/dpo_trainer.py#L797-L799). This causes issues like https://huggingface.slack.com/archives/C04EX6W3QSY/p1715255460198239, as raised by @abhishekkrthakur.
- The hard truncation logic also seems debatable: if a sequence is too long, shouldn't we filter it out instead of keeping a truncated response? The truncated response could be an incomplete code snippet or summary (basically bad data). If truncation is really desired, we should do some kind of smart truncation, e.g. truncating at the last paragraph boundary so the sentences stay complete.
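A sketch of what "smart truncation" could mean, cutting at the last paragraph boundary inside the budget. The function name and the newline-token heuristic are assumptions for illustration, not part of this PR:

```python
def smart_truncate(token_ids, max_len, newline_id):
    """Truncate to at most `max_len` tokens, preferring to cut at the last
    paragraph boundary (a newline token) so the kept text stays complete.

    Falls back to a hard cut if no boundary exists within the window.
    """
    if len(token_ids) <= max_len:
        return token_ids
    window = token_ids[:max_len]
    # Scan backwards for the last paragraph boundary within the budget.
    for i in range(len(window) - 1, -1, -1):
        if window[i] == newline_id:
            return window[: i + 1]
    return window
```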
- Learning to generate EOS tokens. https://github.com/huggingface/trl/issues/1623#issuecomment-2113230864 suggested that EOS tokens always 1) correspond to -100 in the labels and 2) if the dataset contains the EOS token before collating, have an attention mask of 1. As a result, the model may never learn to generate EOS tokens.
  - What's a bit unclear to me is how zephyr learns to output EOS tokens despite all EOS-token labels being set to -100 and masked out. My suspicion is that attention_mask=1 plays some role in it.
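A sketch of the label construction that linked comment argues for: mask the prompt with -100 but keep a real label on the EOS token so the loss actually rewards emitting it. The helper name and layout are assumptions for illustration:

```python
def build_labels(input_ids, prompt_len):
    """Causal-LM labels for a prompt + completion sequence.

    Prompt positions are masked with -100 (ignored by the loss), while
    completion positions (including a trailing EOS) keep their token ids.
    If EOS were also set to -100, the model would get no gradient signal
    for ending the sequence.
    """
    return [-100] * prompt_len + list(input_ids[prompt_len:])
```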
- `dataset_num_proc` is not uniformly applied; as a result, #1624 is needed. There is also the question of hashable tokenization.
- A dataset mixer (e.g., the one in our h4 codebase) should be more widely available in TRL and could be combined with this class.
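A minimal sketch of what such a mixer could look like. The names and the fraction-based API are assumptions; the h4 implementation may differ:

```python
import random

def mix_datasets(named_datasets, fractions, seed=0):
    """Subsample each dataset by its fraction, then shuffle the union.

    `named_datasets` maps name -> list of examples; `fractions` maps
    name -> float in [0, 1] (defaults to 1.0, i.e. take everything).
    """
    rng = random.Random(seed)
    mixed = []
    for name, rows in named_datasets.items():
        k = int(len(rows) * fractions.get(name, 1.0))
        mixed.extend(rng.sample(rows, k))
    rng.shuffle(mixed)
    return mixed
```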
The current design
The current design roughly looks like this. Note that we can still put it in `Trainer.__init__` so users don't have to configure it directly.
dataset_config = DatasetConfig(max_token_length=1024, max_prompt_token_length=128)
dataset_processor = PreferenceDatasetProcessor(tokenizer=tok, config=dataset_config)
train_dataset = dataset_processor.tokenize(preference_datasets["train"])
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)
train_dataset = dataset_processor.filter(train_dataset)
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)
dataset_processor.get_token_length_visualization(train_dataset)
print(tok.decode(train_dataset[0]["chosen"]))
visualize_token(train_dataset[0]["chosen"], tok)
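For completeness, the `filter` step above could be as simple as dropping (rather than truncating) over-length examples, matching the filter-over-truncate argument earlier. A sketch under assumed column names (`chosen`/`rejected` holding token ids):

```python
def filter_too_long(rows, max_token_length, keys=("chosen", "rejected")):
    # Drop examples whose tokenized length exceeds the budget, instead of
    # truncating them into potentially incomplete (bad) data.
    return [
        row
        for row in rows
        if all(len(row[key]) <= max_token_length for key in keys)
    ]
```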