FlashAttention V3
⚠️ Please check that this feature request hasn't been suggested before.
- [x] I searched previous Ideas in Discussions and didn't find any similar feature requests.
- [x] I searched previous Issues and didn't find any similar feature requests.
🔖 Feature description
Flash Attention v3 can speed up training by 1.5-2.0x by properly utilizing the Hopper instructions.
Therefore, it makes a lot of sense to start integrating it into axolotl now that v3 is ready to be used. Ideally, I think v3 should still be compatible with sample packing since it has the same interface.
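For context on the sample-packing point: packing goes through the varlen entry point with cumulative sequence lengths, which exists in both versions. Below is a minimal sketch of that call against the v3 module, assuming the positional argument layout matches v2 minus `dropout_p` (the layout and return type here are assumptions, not verified against the hopper build):

```python
import torch
from flash_attn_interface import flash_attn_varlen_func  # FA v3, built from hopper/

# Two packed sequences of lengths 3 and 5 (8 tokens total), 8 heads, head_dim 64,
# in the (total_tokens, nheads, headdim) layout flash-attn expects.
q = torch.randn(8, 8, 64, device="cuda", dtype=torch.bfloat16)
k, v = q.clone(), q.clone()
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

result = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens, cu_seqlens,  # cu_seqlens_q, cu_seqlens_k
    5, 5,                    # max_seqlen_q, max_seqlen_k
    causal=True,
)
# Per the discussion below, v3 returns (attn_out, softmax_lse) while v2 returns attn_out only.
attn_out = result[0] if isinstance(result, tuple) else result
```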
✔️ Solution
IBM recently implemented FAv3 in their dolomite engine alongside FAv2. I think this is a good example that the axolotl maintainers could draw inspiration from: https://github.com/IBM/dolomite-engine/commit/03d828e4b877ff4232c5a739bb99104b24c71807
Install from here: https://github.com/Dao-AILab/flash-attention/tree/main/hopper
❓ Alternatives
No response
📝 Additional Context
No response
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this feature has not been requested yet.
- [x] I have provided enough information for the maintainers to understand and evaluate this request.
Is the API exactly the same? Just from a quick look, I see dropout is not passed in v3.
What in particular do we need to change on our end, or is it just the installation method that differs? I saw that it's described as specialized for Hopper, so would it not work on non-Hopper architectures?
The API is close to the same as in v2 and is almost a drop-in replacement, but not completely, as you outlined. On your end, you need to install it and apply the right functions in the right places.
Yes, it's specialized for Hopper hardware, which is why it's able to achieve the 1.5-2.0x speedup.
Do you know whether FA3 dropped support for any non-Hopper arch compared to FA2? I recall FA2 dropped support for Turing, and it was never added back.
FA3 is specifically developed for Hopper because the architecture has new instructions that previous architectures do not have.
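For illustration, a runtime guard could gate the v3 path on compute capability (Hopper reports 9.x through PyTorch). This is only a sketch, not existing axolotl code:

```python
import torch

def supports_flash_attn_v3() -> bool:
    """Rough check: FA3's kernels target Hopper (compute capability 9.x, e.g. H100/H200)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9
```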
A small script to compile the kernels is shown below. I think this has a lot of potential :)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
pip install ninja packaging
FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install
```
NOTE: updated the build command to exclude building SM80 and FP8 to speed up compilation, since neither feature will be used.
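Once the build finishes, a quick sanity check (assuming the wheel exposes the module as `flash_attn_interface`, as referenced later in this thread) could be:

```python
# Verify that the v2 and v3 kernels are importable side by side.
import flash_attn             # FA v2
import flash_attn_interface   # FA v3, built from the hopper/ directory

print("flash_attn (v2):", flash_attn.__version__)
print("flash_attn_interface has varlen func:",
      hasattr(flash_attn_interface, "flash_attn_varlen_func"))
```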
Since this is Hopper-only, one option could be a dedicated Hopper Docker image.
We still need to handle any API signature changes, though.
@NanoCode012 There are no API changes except missing features. So you really want both v2 and v3 installed for this to work.
I would suggest this code to test things out! (`flash_attn` = v2 and `flash_attn_interface` = v3)
```python
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
    if has_remote_code:
        patch_remote(model_name)
    elif hasattr(transformers, "modeling_flash_attention_utils"):
        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        # Monkeypatch flash_attn_func and flash_attn_varlen_func from flash_attn_interface
        try:
            from flash_attn_interface import flash_attn_func as flash_attn_func_v3
            from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

            transformers.modeling_flash_attention_utils.flash_attn_func = flash_attn_func_v3
            transformers.modeling_flash_attention_utils.flash_attn_varlen_func = flash_attn_varlen_func_v3
            print("Successfully patched flash_attn_func and flash_attn_varlen_func with flash_attn_interface (v3)")
        except ImportError:
            print("flash_attn_interface not installed. Skipping flash_attn_func monkeypatch.")
    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
        patch_mixtral_moe_forward_zero3()
```
We'll address this in two parts. The first is to ship the base image with it preinstalled (#2685). Then we can tackle the patch and make sure CI runs against it. We should also upstream this fix.
@casper-hansen That didn't work out of the box for you, I'm assuming? It looks like the FA3 API doesn't support dropout_p and changed the outputs?
I never actually finished building it; I got impatient because it takes a long time and I wanted to do other things.
- dropout_p: yes, this is dropped. Not a big deal IMO; not a lot of models use this, only very old ones.
- Outputs: slightly different. You get `(attn_out, softmax_lse)` in v3 vs. just `attn_out` in v2 (see the sketch below).
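If the return-type difference gets in the way of the patch above, a small shim (purely a sketch, not something flash-attention ships) could normalize v3's output back to the v2 convention:

```python
from flash_attn_interface import flash_attn_func as _flash_attn_func_v3

def flash_attn_func_v3_compat(*args, dropout_p=0.0, **kwargs):
    """Wrapper sketch: swallow dropout_p (unsupported in v3) and return only attn_out."""
    # NOTE: silently dropping dropout_p is only safe for models without attention dropout.
    out = _flash_attn_func_v3(*args, **kwargs)
    # v3 returns (attn_out, softmax_lse); v2 returns attn_out only.
    return out[0] if isinstance(out, tuple) else out
```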
Transformers added FA v3 support upstream. I think we just need to add a change and set the `attn_implementation` now. This could be follow-up work after the attention refactor is in.
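For reference, a minimal sketch of what that could look like; the exact string (assumed here to be `"flash_attention_3"`) and the minimum Transformers version should be checked against the installed release:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",                # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # assumption: name used by recent Transformers for FA v3
)
```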