Deprecated `flash_attn_fuse_qkv` toggle for the latest transformers (4.51.3 or lower)
⚠️ Please check that this feature request hasn't been suggested before.
- [x] I searched previous Ideas in Discussions and didn't find any similar feature requests.
- [x] I searched previous Issues and didn't find any similar feature requests.
🔖 Feature description
The flash_attn_fuse_qkv toggle no longer works (it is effectively deprecated) after recent updates to the transformers library.
This feature works by replacing the forward() function of LlamaAttention/MistralAttention with fused implementations (see llama and mistral). However, the attention forward() signature has changed significantly in the latest (and recent) versions of transformers, which breaks the flash_attn_fuse_qkv feature (see also the sketch after the traceback below):
File "/home/root/miniconda3/envs/text-axolotl/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 318, in forward
hidden_states, self_attn_weights = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/root/miniconda3/envs/text-axolotl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/root/miniconda3/envs/text-axolotl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: flashattn_forward() got an unexpected keyword argument 'cache_position'
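To make the failure mode concrete, here is a minimal sketch of the monkey-patching approach; the flashattn_forward below is only an illustration of an older-style forward signature, not axolotl's exact code:

```python
# Minimal sketch of the monkey-patching approach and why it now fails.
# flashattn_forward here is only a stand-in for axolotl's patched function.
from transformers.models.llama.modeling_llama import LlamaAttention


def flashattn_forward(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
):
    ...  # fused QKV projection + flash-attention would live here


# The toggle replaces the class-level forward, so every LlamaAttention
# instance now routes through the old-style signature above.
LlamaAttention.forward = flashattn_forward

# Newer transformers decoder layers call the attention module with extra
# keyword arguments, e.g. cache_position, which the old signature rejects:
#   TypeError: flashattn_forward() got an unexpected keyword argument 'cache_position'
```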
✔️ Solution
To solve the problem, the patched flashattn_forward() functions should be updated to match the attention forward() signature used by current transformers releases (for example, accepting the new cache_position keyword argument), as sketched below.
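A rough sketch of the required adjustment follows; updated_flashattn_forward and its parameter list are assumptions, not axolotl's actual code. The key point is that the patched function must accept the keyword arguments current transformers passes (notably cache_position), tolerate future ones via **kwargs, and return (attn_output, attn_weights) to match the newer calling convention:

```python
# Sketch only: the fused-QKV / flash-attention body is omitted.
from typing import Optional

import torch


def updated_flashattn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_value=None,
    cache_position: Optional[torch.LongTensor] = None,  # the kwarg that currently breaks the patch
    **kwargs,  # absorb any additional arguments newer transformers may add
):
    # Compute the fused QKV projection, apply rotary embeddings from
    # position_embeddings, run flash-attention, and return a tuple of
    # (attn_output, attn_weights) as the newer decoder layers expect.
    raise NotImplementedError("illustrative signature only")
```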
❓ Alternatives
A workaround is to enforce an upper bound on the transformers version when this feature is enabled; see the sketch below.
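For illustration, such a guard could look roughly like this; the check_flash_attn_fuse_qkv helper and the cutoff version are assumptions, and the real last-working transformers version would have to be determined by testing:

```python
# Hypothetical version guard; 4.47.1 is an assumed example cutoff, not a
# verified boundary.
import transformers
from packaging import version

MAX_SUPPORTED_TRANSFORMERS = "4.47.1"  # assumed example only


def check_flash_attn_fuse_qkv(cfg):
    if cfg.get("flash_attn_fuse_qkv") and version.parse(
        transformers.__version__
    ) > version.parse(MAX_SUPPORTED_TRANSFORMERS):
        raise ValueError(
            "flash_attn_fuse_qkv is only known to work with transformers "
            f"<= {MAX_SUPPORTED_TRANSFORMERS}; please disable it or downgrade."
        )
```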
📝 Additional Context
No response
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this feature has not been requested yet.
- [x] I have provided enough information for the maintainers to understand and evaluate this request.
Thanks for the report! Would you be interested in adding a deprecation error for that config in a PR?
I'm not sure a transformers version check would work, since our code generally tracks the latest transformers, and using an older version would likely cause other import errors.
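For reference, a minimal sketch of such a deprecation error, assuming a dict-like config and a hypothetical validation hook rather than axolotl's actual validation API:

```python
# Hypothetical placement; where this check belongs in axolotl's config
# validation is an assumption.
def validate_config(cfg):
    if cfg.get("flash_attn_fuse_qkv"):
        raise ValueError(
            "`flash_attn_fuse_qkv` has been deprecated and no longer works "
            "with the transformers versions axolotl targets; please remove "
            "it from your config."
        )
```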