
FlashAttention V3

Open casper-hansen opened this issue 7 months ago • 11 comments

⚠️ Please check that this feature request hasn't been suggested before.

  • [x] I searched previous Ideas in Discussions and didn't find any similar feature requests.
  • [x] I searched previous Issues and didn't find any similar feature requests.

🔖 Feature description

Flash Attention v3 can speed up attention by 1.5-2.0x over v2 by making proper use of Hopper-specific GPU instructions.

Therefore, it makes a lot of sense to start integrating it into axolotl now that v3 is ready to be used. Ideally, I think v3 should still be compatible with sample packing since it has the same interface.
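For context on why sample packing should carry over: packed rows are fed to the varlen kernels (flash_attn_varlen_func in v2, and the function of the same name in flash_attn_interface for v3), which locate each sequence through cumulative sequence lengths (cu_seqlens). A minimal pure-Python sketch of deriving those boundaries, assuming each packed sequence's position ids restart at 0; this illustrates the idea and is not axolotl's actual implementation. In practice the kernels take cu_seqlens as an int32 tensor on the GPU.

```python
from itertools import accumulate
from typing import List, Tuple


def cu_seqlens_from_position_ids(position_ids: List[int]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) for one packed row.

    Assumes position ids restart at 0 for every packed sequence, e.g.
    [0, 1, 2, 0, 1, 0, 1, 2, 3] packs sequences of length 3, 2 and 4.
    """
    if not position_ids:
        return [0], 0
    lengths: List[int] = []
    current = 0
    for pos in position_ids:
        if pos == 0 and current > 0:
            # A new sequence starts here: close the previous one.
            lengths.append(current)
            current = 0
        current += 1
    lengths.append(current)
    # Cumulative offsets with a leading 0, the layout the varlen kernels use.
    cu_seqlens = [0, *accumulate(lengths)]
    return cu_seqlens, max(lengths)


if __name__ == "__main__":
    print(cu_seqlens_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3]))  # ([0, 3, 5, 9], 4)
```

Because v2 and v3 both consume this same (cu_seqlens, max_seqlen) description, the packing side should not need to change; only the kernel being called does.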

✔️ Solution

IBM recently implemented FAv3 in their dolomite engine alongside FAv2. I think this is a good reference for the axolotl maintainers. https://github.com/IBM/dolomite-engine/commit/03d828e4b877ff4232c5a739bb99104b24c71807

Install from here: https://github.com/Dao-AILab/flash-attention/tree/main/hopper

❓ Alternatives

No response

📝 Additional Context

No response

Acknowledgements

  • [x] My issue title is concise, descriptive, and in title casing.
  • [x] I have searched the existing issues to make sure this feature has not been requested yet.
  • [x] I have provided enough information for the maintainers to understand and evaluate this request.

casper-hansen • Apr 30 '25

Is the API exactly the same? From a quick look, I see that v3 doesn't accept dropout.

What do we need to change on our end, or is it just the installation method that differs? I saw that it's described as specialized for Hopper; would it not work on non-Hopper architectures?

NanoCode012 • Apr 30 '25

The API is close to v2's and almost a drop-in replacement, but not completely, as you pointed out. On your end, you need to install it and patch in the right functions.

Yes, it's specialized for Hopper hardware, which is how it achieves the 1.5-2.0x speedup.
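To put the 1.5-2.0x in context: it applies to the attention kernels, so the end-to-end training speedup depends on how much of each step is spent in attention. A back-of-the-envelope Amdahl's-law sketch; the attention shares in the demo are illustrative assumptions, not measurements:

```python
def end_to_end_speedup(attn_fraction: float, attn_speedup: float) -> float:
    """Amdahl's law: overall speedup when only the attention share of a
    training step becomes `attn_speedup` times faster."""
    if not 0.0 <= attn_fraction <= 1.0:
        raise ValueError("attn_fraction must be in [0, 1]")
    if attn_speedup <= 0:
        raise ValueError("attn_speedup must be positive")
    # Normalized step time after the speedup (the original step time is 1.0).
    new_step_time = (1.0 - attn_fraction) + attn_fraction / attn_speedup
    return 1.0 / new_step_time


if __name__ == "__main__":
    # Illustrative attention shares only; profile your own model for real numbers.
    for share in (0.3, 0.5, 0.7):
        for kernel_speedup in (1.5, 2.0):
            overall = end_to_end_speedup(share, kernel_speedup)
            print(f"attention {share:.0%} of step, kernel {kernel_speedup}x -> {overall:.2f}x overall")
```

For instance, if attention were half of a step, a 2x faster kernel would give about 1.33x overall; the gain approaches the full kernel speedup only as attention dominates the step (e.g. at long sequence lengths).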

casper-hansen • Apr 30 '25

Do you know whether FA3 dropped support for any non-Hopper architectures compared to FA2? I recall FA2 dropped Turing, and support still hasn't been added back.

NanoCode012 • Apr 30 '25

FA3 is developed specifically for Hopper because that architecture has new instructions that previous architectures lack.

casper-hansen • Apr 30 '25

A small script to compile the kernels is below. I think this has a lot of potential :)

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
pip install ninja packaging
FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install

NOTE: I updated the build command to skip building the SM80 and FP8 kernels, which speeds up compilation since neither will be used.

casper-hansen • May 16 '25

Since this is Hopper-only, one option could be a dedicated Hopper Docker image.

We still need to handle any API signature changes though.

NanoCode012 • May 16 '25

@NanoCode012 There are no API changes apart from missing features, so you'd want both v2 and v3 installed for this to work.

I'd suggest this code to test things out (flash_attn is v2 and flash_attn_interface is v3):

import transformers
from transformers.integrations import is_deepspeed_zero3_enabled

# get_unpad_data, patch_remote, and patch_mixtral_moe_forward_zero3 are
# axolotl-internal helpers defined elsewhere in axolotl's monkeypatch code.


def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
    if has_remote_code:
        patch_remote(model_name)
    elif hasattr(transformers, "modeling_flash_attention_utils"):
        # Swap in axolotl's packing-aware unpadding so packed samples stay separate
        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )

        # Monkeypatch flash_attn_func and flash_attn_varlen_func from flash_attn_interface
        try:
            from flash_attn_interface import flash_attn_func as flash_attn_func_v3
            from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

            transformers.modeling_flash_attention_utils.flash_attn_func = flash_attn_func_v3
            transformers.modeling_flash_attention_utils.flash_attn_varlen_func = flash_attn_varlen_func_v3
            print("Successfully patched flash_attn_func and flash_attn_varlen_func with flash_attn_interface (v3)")
        except ImportError:
            # v3 not built/installed: keep the v2 kernels from flash_attn
            print("flash_attn_interface not installed. Skipping flash_attn_func monkeypatch.")

    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
        patch_mixtral_moe_forward_zero3()

casper-hansen • May 16 '25

We'll address this in two parts. First, ship the base image with it preinstalled (#2685). Then we can tackle the patch and make sure CI runs against it. We should also upstream this fix.

winglian • May 16 '25

@casper-hansen I'm assuming that didn't work out of the box for you? It looks like the FA3 API doesn't support dropout_p and changed the outputs.

winglian • May 18 '25

I never actually finished building it; it takes a long time, and I got impatient and moved on to other things.

  • dropout_p: yes, this was dropped. Not a big deal IMO; few models use attention dropout, mostly very old ones.
  • outputs: slightly different. v3 returns (attn_out, softmax_lse), whereas v2 returns just attn_out.
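Given those two differences, one option is a thin adapter that lets v2-style callers use the v3 kernels: it rejects a nonzero dropout_p and unwraps the tuple. A minimal sketch, assuming v3 returns (attn_out, softmax_lse) as described above; since the return shape may vary across flash_attn_interface releases, a bare result is passed through unchanged. adapt_v3_to_v2 and dropout_index are hypothetical names: dropout_index is where dropout_p sits in v2's positional signature (3 for flash_attn_func, 7 for flash_attn_varlen_func). Arguments after dropout_p must be passed by keyword, because v3's positional order may differ from v2's.

```python
import functools
from typing import Any, Callable


def adapt_v3_to_v2(v3_fn: Callable[..., Any], dropout_index: int) -> Callable[..., Any]:
    """Hypothetical wrapper so v2-style callers can use a FlashAttention v3 function.

    dropout_index is the position of dropout_p in the v2 signature:
    3 for flash_attn_func, 7 for flash_attn_varlen_func.
    """

    @functools.wraps(v3_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if len(args) > dropout_index + 1:
            raise TypeError("pass arguments after dropout_p by keyword; "
                            "v3's positional order may differ from v2's")
        if len(args) == dropout_index + 1:
            # Some v2 callers pass dropout_p positionally.
            if "dropout_p" in kwargs:
                raise TypeError("dropout_p given both positionally and by keyword")
            *args, dropout_p = args
        else:
            dropout_p = kwargs.pop("dropout_p", 0.0)
        # v3 has no attention dropout: fail loudly instead of silently ignoring it.
        if dropout_p:
            raise ValueError(f"FlashAttention v3 does not support dropout_p={dropout_p}")
        out = v3_fn(*args, **kwargs)
        # v3 returns (attn_out, softmax_lse); v2 callers expect just attn_out.
        return out[0] if isinstance(out, tuple) else out

    return wrapper


if __name__ == "__main__":
    def fake_v3(q, k, v, softmax_scale=None, causal=False):
        """Stand-in for a v3 kernel so the sketch runs without a GPU."""
        return ("attn_out", "softmax_lse")

    fa3_func = adapt_v3_to_v2(fake_v3, dropout_index=3)
    print(fa3_func("q", "k", "v", 0.0, causal=True))  # prints: attn_out
```

With the v3 build installed, the wrapped versions would be adapt_v3_to_v2(flash_attn_func, 3) and adapt_v3_to_v2(flash_attn_varlen_func, 7), using the functions imported from flash_attn_interface. Raising on a nonzero dropout_p, rather than dropping it, keeps a model that relies on attention dropout from silently training differently.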

casper-hansen • May 18 '25

Transformers has added FA3 support upstream, so I think we now just need to make the change and set attn_implementation accordingly. This could be follow-up work once the attention refactor is in.
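With upstream support, switching should mostly come down to the attn_implementation value passed when loading the model. A minimal sketch of choosing that value with fallbacks, assuming "flash_attention_3" is the name transformers uses for the FA3 backend (worth verifying against the installed transformers version) and that Hopper reports CUDA compute capability 9.0. pick_attn_implementation is a hypothetical helper, not part of axolotl or transformers.

```python
import importlib.util
from typing import Optional, Tuple


def _installed(module: str) -> bool:
    """True if `module` can be imported (checked without importing it)."""
    return importlib.util.find_spec(module) is not None


def pick_attn_implementation(
    capability: Tuple[int, int],
    fa3_installed: Optional[bool] = None,
    fa2_installed: Optional[bool] = None,
) -> str:
    """Hypothetical helper: prefer FA3 on Hopper, then FA2, else PyTorch SDPA.

    `capability` is the CUDA compute capability, e.g. (9, 0) for an H100.
    """
    if fa3_installed is None:
        fa3_installed = _installed("flash_attn_interface")  # v3 build from hopper/
    if fa2_installed is None:
        fa2_installed = _installed("flash_attn")  # v2 package
    major = capability[0]
    if major == 9 and fa3_installed:
        return "flash_attention_3"
    if major in (8, 9) and fa2_installed:
        # FA2 targets Ampere/Ada (8.x) and Hopper (9.0); Turing (7.5) is not supported.
        return "flash_attention_2"
    return "sdpa"


# Usage sketch (needs torch, transformers, and a CUDA GPU; model_id is a placeholder):
#   import torch
#   from transformers import AutoModelForCausalLM
#   impl = pick_attn_implementation(torch.cuda.get_device_capability())
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, attn_implementation=impl, torch_dtype=torch.bfloat16
#   )
```

Falling back to "sdpa" rather than failing keeps the same config usable on non-Hopper GPUs, which matters if FA3 ends up only in a dedicated Hopper image.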

NanoCode012 • Aug 05 '25