
Remove graph breaks for torch.compile() in flash_attention_forward when a Llama model is padding-free tuned

Open Abhishek-TAMU opened this issue 1 year ago • 13 comments

What does this PR do?

This PR removes the call to prepare_fa2_from_position_ids in flash_attention_forward. That call causes a graph break when the torch_compile flag is enabled in the training arguments passed to SFTTrainer for padding-free tuning of a Llama model, because the code in prepare_fa2_from_position_ids incurs an unavoidable CPU-GPU sync. With this PR, cu_seq_lens_q, cu_seq_lens_k, max_length_k, and max_length_q are instead taken from the batch produced by DataCollatorForCompletionOnlyLM, so flash_attention_forward no longer needs to call prepare_fa2_from_position_ids.
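As an illustration of why moving this work into the collator helps, here is a minimal sketch, assuming packed rows whose position ids restart at 0 for every new sequence. The helper name `fa2_kwargs_from_position_ids` is hypothetical and this is not the code from the PR; only the four output keys come from the description above. It runs in plain Python on the CPU, so no GPU value has to be read back inside the compiled forward pass. A real collator would presumably convert these values to `int32` tensors before returning the batch.

```python
from typing import Dict, List


def fa2_kwargs_from_position_ids(position_ids: List[List[int]]) -> Dict[str, object]:
    """Hypothetical collator-side helper (illustration only, not the PR's code).

    Assumes each row packs several sequences whose position ids restart at 0.
    The batch is flattened, as in padding-free training, and the start offset
    of every sequence is recorded.
    """
    flat = [p for row in position_ids for p in row]
    # A new sequence starts wherever the position id resets to 0.
    starts = [i for i, p in enumerate(flat) if p == 0]
    # Cumulative sequence lengths: each start offset plus the total token count.
    cu_seq_lens = starts + [len(flat)]
    # Longest packed sequence, computed on the CPU in the collator so the
    # model never has to pull this value off the GPU.
    max_length = max(end - start for start, end in zip(cu_seq_lens, cu_seq_lens[1:]))
    # Self-attention over the same packed tokens: query and key layouts match.
    return {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max_length,
        "max_length_k": max_length,
    }
```

For example, one row packing sequences of length 3, 2 and 4 (`[[0, 1, 2, 0, 1, 0, 1, 2, 3]]`) yields `cu_seq_lens_q == [0, 3, 5, 9]` and `max_length_q == 4`.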

CC: @ani300 @ArthurZucker

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Abhishek-TAMU, Oct 03 '24 22:10