
Allow usage of DataCollatorForCompletionOnlyLM for SFT Training

Open mlmonk opened this issue 1 year ago • 6 comments

⚠️ Please check that this feature request hasn't been suggested before.

  • [X] I searched previous Ideas in Discussions and didn't find any similar feature requests.
  • [X] I searched previous Issues and didn't find any similar feature requests.

🔖 Feature description

It would be great to support the DataCollatorForCompletionOnlyLM collator through the YAML config. In my experience, it often leads to a higher-quality model.

✔️ Solution

Adding a config option called collator to the YAML, where you can specify which collator to use, would let users fine-tune a whole new set of models and greatly help with experimentation. For this to work, it would also need to support new hyperparameters such as packing and dataset_text_field, since packing=True is currently not supported with DataCollatorForCompletionOnlyLM.
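To make the proposal concrete, here is a minimal Python sketch of how such a key might be validated once the YAML has been loaded into a dict. The collator key, its values "default" and "completion_only", and the parse_collator_config helper are all hypothetical; they are not part of axolotl's real config schema.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical values for the proposed `collator` key; not part of axolotl's schema.
SUPPORTED_COLLATORS = {"default", "completion_only"}


@dataclass
class CollatorConfig:
    collator: str = "default"
    packing: bool = False
    dataset_text_field: Optional[str] = None


def parse_collator_config(raw: dict) -> CollatorConfig:
    """Read the proposed keys from an already-loaded YAML dict and reject
    combinations the issue says are unsupported."""
    cfg = CollatorConfig(
        collator=raw.get("collator", "default"),
        packing=bool(raw.get("packing", False)),
        dataset_text_field=raw.get("dataset_text_field"),
    )
    if cfg.collator not in SUPPORTED_COLLATORS:
        raise ValueError(f"unknown collator: {cfg.collator!r}")
    if cfg.collator == "completion_only" and cfg.packing:
        # Per the issue, packing=True is not supported with DataCollatorForCompletionOnlyLM.
        raise ValueError("packing=True cannot be combined with collator: completion_only")
    return cfg
```

Failing at config-parse time, rather than inside the trainer, surfaces the packing conflict before any data is tokenized.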

❓ Alternatives

No response

📝 Additional Context

No response

Acknowledgements

  • [X] My issue title is concise, descriptive, and in title casing.
  • [X] I have searched the existing issues to make sure this feature has not been requested yet.
  • [X] I have provided enough information for the maintainers to understand and evaluate this request.

mlmonk · Jan 05 '24 17:01

Could you provide some details about what this collator does differently?

winglian · Jan 05 '24 17:01

The default collator for SFTTrainer is DataCollatorForLanguageModeling (with mlm=False), which trains the model on both the instruction AND the completion using the standard causal language-modeling loss, whereas DataCollatorForCompletionOnlyLM computes the loss only on the completion and NOT on the instruction. The best explanation I know of is here.
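A minimal sketch of the difference, on plain lists of made-up token ids rather than a real tokenizer. IGNORE_INDEX = -100 is the default ignore_index of PyTorch's CrossEntropyLoss, which Hugging Face causal-LM models use for their loss. The template search below follows the idea behind DataCollatorForCompletionOnlyLM (locate the response-template tokens and mask everything up to and including them), but it is an illustration, not trl's implementation.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def full_sequence_labels(input_ids: List[int]) -> List[int]:
    """Causal-LM labels as with DataCollatorForLanguageModeling(mlm=False),
    ignoring padding: every token, prompt included, is a training target."""
    return list(input_ids)


def completion_only_labels(input_ids: List[int], response_template: List[int]) -> List[int]:
    """Mask every token up to and including the first occurrence of the
    response template, so only the completion contributes to the loss."""
    n = len(response_template)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template:
            end = start + n
            return [IGNORE_INDEX] * end + input_ids[end:]
    # Template not found: this sketch masks the whole example.
    return [IGNORE_INDEX] * len(input_ids)
```

For ids = [1, 10, 11, 12, 20, 21, 30, 31, 2] with response_template = [20, 21], the full-sequence labels are the ids themselves, while the completion-only labels are six -100 entries followed by [30, 31, 2].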

mlmonk · Jan 05 '24 17:01

We already mask out the instruction when train_on_inputs: false is set. Doesn't that amount to the same thing at the end of the day?
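A sketch of the tokenization-time approach described above, assuming the prompt and the response are tokenized separately and then concatenated. The build_example helper and the token ids are hypothetical and simplified; this is not axolotl's actual code.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value ignored by the loss


def build_example(prompt_ids: List[int], response_ids: List[int],
                  train_on_inputs: bool) -> Dict[str, List[int]]:
    """Concatenate prompt and response, masking the prompt labels unless
    train_on_inputs is set (a hypothetical helper for illustration)."""
    input_ids = prompt_ids + response_ids
    if train_on_inputs:
        labels = list(input_ids)
    else:
        # The prompt length is known at tokenization time, so no template search is needed.
        labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels}
```

With prompt_ids = [1, 10, 11, 12, 20, 21] (a prompt ending in the response marker) and response_ids = [30, 31, 2], the labels are six -100 entries followed by [30, 31, 2], the same mask a completion-only collator would produce on the concatenated sequence. Because the mask is attached to each example before batching, the collator no longer needs to search for a template.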

winglian · Jan 05 '24 18:01

From the docs for the basic data collator (https://huggingface.co/docs/transformers/main_classes/data_collator), the primary responsibility of a collator seems to be taking multiple rows and building them into a batch. The trl collator tries to do several things at once, which can make it harder to compose with the various configurations we support.
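Under that narrower definition, a collator only pads already-masked examples into a batch. A minimal sketch, assuming right padding and examples with input_ids, attention_mask, and labels fields; pad_collate is a hypothetical helper that returns lists rather than tensors. transformers' DataCollatorForSeq2Seq takes a similar approach, padding labels with label_pad_token_id (default -100).

```python
from typing import Dict, List

IGNORE_INDEX = -100  # padding value for labels so padded positions add no loss


def pad_collate(features: List[Dict[str, List[int]]],
                pad_token_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad pre-tokenized, pre-masked examples to a common length.
    Existing label masks are left untouched; only padding is appended."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * pad)
        batch["labels"].append(f["labels"] + [IGNORE_INDEX] * pad)
    return batch
```

Keeping masking out of the collator means the same padding logic works whether labels were masked for train_on_inputs: false or left whole.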

winglian · Jan 06 '24 15:01

@winglian Agreed! train_on_inputs: false is what I really wanted to achieve. Could you point me to where that actually happens in the code? I couldn't find it myself.

mlmonk · Jan 12 '24 23:01