Allow usage of DataCollatorForCompletionOnlyLM for SFT Training
⚠️ Please check that this feature request hasn't been suggested before.
- [X] I searched previous Ideas in Discussions didn't find any similar feature requests.
- [X] I searched previous Issues didn't find any similar feature requests.
🔖 Feature description
It would be great to support the DataCollatorForCompletionOnlyLM collator through the YAML config. In my experience, training on completions only often leads to higher-quality models.
✔️ Solution
Adding an extra config option called collator to the YAML, where you can specify the collator to be used, would allow users to fine-tune a whole new set of models and greatly help with experimentation. For this to work, it will also need to support related hyperparameters like packing and dataset_text_field, since packing=True is currently not supported with DataCollatorForCompletionOnlyLM.
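As a rough sketch of what this could look like in the YAML (the collator key is hypothetical and does not exist in axolotl today; key names are illustrative only):

```yaml
# Hypothetical config sketch for this feature request
collator: DataCollatorForCompletionOnlyLM
dataset_text_field: text
sample_packing: false  # packing is not supported with this collator
```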
❓ Alternatives
No response
📝 Additional Context
No response
Acknowledgements
- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this feature has not been requested yet.
- [X] I have provided enough information for the maintainers to understand and evaluate this request.
Could you provide some details about what this collator does differently?
The default collator for SFTTrainer is DataCollatorForLanguageModeling (with mlm=False), which fine-tunes the model on the instruction AND the completion by applying the causal language modeling loss over the whole sequence, whereas DataCollatorForCompletionOnlyLM trains exclusively on the completion and NOT on the instruction. The best explanation I know of is here.
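The difference boils down to how the labels are built. A minimal sketch (this is an illustration of the masking idea, not the actual trl implementation):

```python
# Tokens labeled -100 are ignored by the cross-entropy loss in transformers.
IGNORE_INDEX = -100

def full_sequence_labels(input_ids):
    """DataCollatorForLanguageModeling(mlm=False) style: loss on every token."""
    return list(input_ids)

def completion_only_labels(input_ids, response_start):
    """DataCollatorForCompletionOnlyLM style: mask everything before the
    completion so the loss is computed on the response tokens only."""
    return [IGNORE_INDEX] * response_start + list(input_ids[response_start:])

# Toy example: positions 0-3 are the instruction, 4-6 are the completion.
ids = [101, 7592, 2088, 102, 2023, 2003, 102]
full_sequence_labels(ids)        # loss over all 7 tokens
completion_only_labels(ids, 4)   # loss over the last 3 tokens only
```

In trl, the response boundary is found by searching for a response_template token sequence rather than being passed in explicitly.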
We already mask out the instruction when train_on_inputs: false is set. Does that sound like it amounts to the same thing at the end of the day?
In reading the docs for the basic data collator (https://huggingface.co/docs/transformers/main_classes/data_collator), it seems like the primary responsibility of the collator is to take multiple rows and build out the batch. It seems the trl collator tries to do several things, which can make it harder to compose functionality given the various configurations we support.
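That core responsibility can be sketched in a few lines (a toy illustration of batching/padding, not axolotl's or transformers' actual code; all names are made up):

```python
# A collator's basic job: turn several variable-length rows into one
# rectangular batch, padding shorter rows and masking the padding.
def collate(rows, pad_token_id=0):
    max_len = max(len(r) for r in rows)
    input_ids, attention_mask = [], []
    for r in rows:
        n_pad = max_len - len(r)
        input_ids.append(list(r) + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(r) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = collate([[5, 6, 7], [8, 9]])
# batch["input_ids"]      -> [[5, 6, 7], [8, 9, 0]]
# batch["attention_mask"] -> [[1, 1, 1], [1, 1, 0]]
```

Label construction (as DataCollatorForCompletionOnlyLM does) is an extra responsibility layered on top of this, which is where the composition concerns come in.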
@winglian Agreed!
train_on_inputs: false is what I really wanted to achieve. Could you point me to where that actually happens in the code? I couldn't find it myself.