
Flash attention support.

Open hazemessamm opened this issue 1 year ago • 1 comment

I added support for flash attention for PyTorch.

Let me know what you think about the current implementation, so I can then add support for JAX and maybe try TF as well.
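For context, here is a minimal sketch of how flash attention can be requested through PyTorch's scaled dot-product attention entry point (this is only an illustration assuming a recent PyTorch, 2.3+; the actual implementation in this PR may differ):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes; PyTorch expects (batch, num_heads, seq_len, head_dim).
query = torch.randn(1, 2, 10, 16, dtype=torch.float16, device="cuda")
key = torch.randn(1, 2, 10, 16, dtype=torch.float16, device="cuda")
value = torch.randn(1, 2, 10, 16, dtype=torch.float16, device="cuda")

# Restrict backend selection to the flash attention kernel; if the kernel
# cannot run here (dtype/GPU constraints), this raises a RuntimeError.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)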

hazemessamm avatar Aug 22 '24 18:08 hazemessamm

Codecov Report

Attention: Patch coverage is 26.31579% with 14 lines in your changes missing coverage. Please review.

Project coverage is 78.85%. Comparing base (5aa5f88) to head (57e6e56). Report is 2 commits behind head on master.

Files with missing lines            Patch %   Lines
keras/src/backend/torch/nn.py        18.18%   8 Missing and 1 partial
keras/src/backend/numpy/nn.py         0.00%   1 Missing and 1 partial
keras/src/backend/tensorflow/nn.py    0.00%   1 Missing and 1 partial
keras/src/backend/jax/nn.py          66.66%   1 Missing
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20152      +/-   ##
==========================================
+ Coverage   78.81%   78.85%   +0.04%     
==========================================
  Files         512      513       +1     
  Lines       49063    49250     +187     
  Branches     9035     9080      +45     
==========================================
+ Hits        38668    38837     +169     
- Misses       8530     8543      +13     
- Partials     1865     1870       +5     
Flag              Coverage Δ
keras             78.71% <26.31%> (+0.04%)
keras-jax         62.36% <21.05%> (+0.10%)
keras-numpy       57.38% <10.52%> (-0.03%)
keras-tensorflow  63.62% <10.52%> (+0.06%)
keras-torch       62.35% <15.78%> (+0.09%)

Flags with carried forward coverage won't be shown.


codecov-commenter avatar Aug 22 '24 18:08 codecov-commenter

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] avatar Sep 24 '24 02:09 github-actions[bot]

Hey, sorry for not finishing this PR, I have a quick question, where should I add the tests?

hazemessamm avatar Oct 02 '24 19:10 hazemessamm

Hey, sorry for not finishing this PR, I have a quick question, where should I add the tests?

In keras/src/ops/nn_test.py. Ops are tested through the op class (e.g. in keras/src/ops/nn.py) rather than in a backend-specific way.
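For illustration, such a backend-agnostic correctness test could look roughly like the following (the function and helper names here are hypothetical, not an exact template from the test suite):

import numpy as np
from keras import ops

def _reference_attention(query, key, value):
    # Plain NumPy softmax(QK^T / sqrt(d)) V, with inputs laid out as
    # (batch, seq_len, num_heads, head_dim).
    scale = 1.0 / np.sqrt(query.shape[-1])
    scores = np.einsum("bqhd,bkhd->bhqk", query, key) * scale
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", weights, value)

def test_dot_product_attention_correctness():
    query = np.random.rand(1, 4, 2, 8).astype("float32")
    key = np.random.rand(1, 4, 2, 8).astype("float32")
    value = np.random.rand(1, 4, 2, 8).astype("float32")
    expected = _reference_attention(query, key, value)
    # The flash attention path would be exercised through the same op-level
    # call, regardless of which backend is active.
    out = ops.dot_product_attention(query, key, value)
    np.testing.assert_allclose(out, expected, atol=1e-4)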

fchollet avatar Oct 02 '24 21:10 fchollet

@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?

fchollet avatar Oct 05 '24 18:10 fchollet

@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?

It should be possible to consolidate this into dot_product_attention. That’s how it's implemented in torch, and I've seen a similar approach in jax (https://github.com/jax-ml/jax/blob/81a31f6adf453b2afc39936e15c15d8ad327bf6e/jax/_src/nn/functions.py#L1037-L1041)

As far as I know, for torch, flash attention is utilized if the conditions are met. For jax, we need to specify implementation="cudnn" to use it.
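For reference, the JAX entry point looks roughly like this (shapes are illustrative; jax.nn.dot_product_attention requires a recent JAX release):

import jax
import jax.numpy as jnp

# jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim).
query = jnp.ones((1, 10, 2, 16), dtype=jnp.float16)
key = jnp.ones((1, 10, 2, 16), dtype=jnp.float16)
value = jnp.ones((1, 10, 2, 16), dtype=jnp.float16)

# implementation="cudnn" requests the cuDNN flash attention kernel;
# implementation=None lets JAX choose, "xla" forces the reference path.
out = jax.nn.dot_product_attention(query, key, value, implementation="cudnn")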

james77777778 avatar Oct 06 '24 03:10 james77777778

Very cool -- @hazemessamm can we do that, e.g. by adding a flash_attention argument in dot_product_attention? This makes it quite easy to also add support for JAX (in addition to PyTorch). For TF I think we can skip support for now.
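In other words, the proposal is roughly the following user-facing call (the flash_attention argument shown here is the suggestion under discussion, not an existing API at this point):

import numpy as np
from keras import ops

query = np.random.rand(1, 10, 2, 16).astype("float16")
key = np.random.rand(1, 10, 2, 16).astype("float16")
value = np.random.rand(1, 10, 2, 16).astype("float16")

# Proposed: let each backend dispatch to a flash attention kernel when asked.
out = ops.dot_product_attention(query, key, value, flash_attention=True)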

fchollet avatar Oct 06 '24 04:10 fchollet

Awesome work! Thank you.

Thank you, glad I could help.

hazemessamm avatar Oct 06 '24 19:10 hazemessamm

The test fails on torch + GPU:

FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution.

Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on?

fchollet avatar Oct 06 '24 22:10 fchollet

The test fails on torch + GPU:

FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution.

Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on?

I think flash attention in PyTorch does not work with any dtype except float16, and only on specific GPUs. I just tested it on an H100 GPU and it worked fine, but it did not work on a T4 GPU on Colab.

I also just found the following PyTorch functions, which we can use to check whether the inputs and the current GPU support flash attention.

import torch

# Example inputs in (batch, num_heads, seq_len, head_dim) layout.
bsz, num_heads, seqlen, head_dim = 1, 2, 10, 16
query = torch.randn((bsz, num_heads, seqlen, head_dim), dtype=torch.float32, device='cuda:0')

# Bundle the SDPA inputs (here query is reused as key and value) and ask
# PyTorch whether the flash attention kernel can handle them on this GPU.
params = torch.backends.cuda.SDPAParams(query, query, query, None, 16**-0.5, False)
is_flash_attention_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
print(is_flash_attention_enabled)  # Output: False; it will be True if dtype=torch.float16

If you think this is a good idea, I will use this snippet in the flash attention function in the PyTorch backend.

Documentation: https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.SDPAParams and https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.can_use_flash_attention
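If we go that route, a rough sketch of how the check could be wrapped inside the torch backend (the helper name and fallback behavior here are hypothetical, and the arguments mirror the snippet above):

import torch

def _can_use_flash_attention(query, key, value, mask=None, is_causal=False):
    # Return False on CPU or when the kernel constraints (dtype, head_dim,
    # GPU arch) are not met, so the caller can fall back to the default path.
    if not query.is_cuda:
        return False
    params = torch.backends.cuda.SDPAParams(
        query, key, value, mask, query.shape[-1] ** -0.5, is_causal
    )
    return torch.backends.cuda.can_use_flash_attention(params, False)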

hazemessamm avatar Oct 06 '24 23:10 hazemessamm

If you think that this is a good idea then I will use this snippet in the flash attention function in PyTorch backend.

That sounds great! Then, we can also skip the PyTorch unit test when this check evaluates to False.
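A sketch of what that skip could look like, reusing the same check (test name and structure are illustrative only):

import pytest
from keras import backend

def test_dot_product_attention_flash():
    if backend.backend() == "torch":
        import torch
        if not torch.cuda.is_available():
            pytest.skip("Flash attention needs a CUDA GPU.")
        q = torch.randn(1, 2, 10, 16, dtype=torch.float16, device="cuda")
        params = torch.backends.cuda.SDPAParams(q, q, q, None, 16**-0.5, False)
        if not torch.backends.cuda.can_use_flash_attention(params, False):
            pytest.skip("Flash attention kernel unavailable on this setup.")
    # ...run the usual op-level correctness check here...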

fchollet avatar Oct 07 '24 01:10 fchollet

I skipped the tests for TensorFlow, NumPy and torch. I just tested JAX on a T4 GPU on Colab and got this error: RuntimeError: Require at least Ampere arch to run. So we will need the JAX + GPU tests to run on an Ampere-or-newer GPU, otherwise we will need to skip the tests for all frameworks. Also, the JAX version currently used in the GitHub tests does not have the dot_product_attention function.

hazemessamm avatar Oct 07 '24 16:10 hazemessamm

I added some conditions to skip the JAX tests when they are met. What do you think?
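For the record, the kind of conditions involved might look roughly like this (the compute_capability attribute used below is an assumption on my side and may not exist on all jaxlib versions/devices):

import jax

def _jax_flash_attention_supported():
    # jax.nn.dot_product_attention only exists in recent JAX releases.
    if not hasattr(jax.nn, "dot_product_attention"):
        return False
    device = jax.devices()[0]
    if device.platform != "gpu":
        return False
    # The cuDNN kernel requires at least Ampere (compute capability 8.0).
    capability = getattr(device, "compute_capability", None)  # assumed attribute
    return capability is not None and float(capability) >= 8.0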

hazemessamm avatar Oct 07 '24 17:10 hazemessamm