Flash attention support.
I added support for flash attention for the PyTorch backend.
Let me know what you think about the current implementation, and I can then add support for JAX and maybe try TF as well.
Codecov Report
Attention: Patch coverage is 26.31579% with 14 lines in your changes missing coverage. Please review.
Project coverage is 78.85%. Comparing base (5aa5f88) to head (57e6e56). Report is 2 commits behind head on master.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master   #20152      +/-   ##
==========================================
+ Coverage   78.81%   78.85%   +0.04%
==========================================
  Files         512      513       +1
  Lines       49063    49250     +187
  Branches     9035     9080      +45
==========================================
+ Hits        38668    38837     +169
- Misses       8530     8543      +13
- Partials     1865     1870       +5
```
| Flag | Coverage Δ | |
|---|---|---|
| keras | 78.71% <26.31%> (+0.04%) | :arrow_up: |
| keras-jax | 62.36% <21.05%> (+0.10%) | :arrow_up: |
| keras-numpy | 57.38% <10.52%> (-0.03%) | :arrow_down: |
| keras-tensorflow | 63.62% <10.52%> (+0.06%) | :arrow_up: |
| keras-torch | 62.35% <15.78%> (+0.09%) | :arrow_up: |
Flags with carried forward coverage won't be shown.
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Hey, sorry for not finishing this PR. I have a quick question: where should I add the tests?
In keras/src/ops/nn_test.py. Ops are tested through the op class (e.g. in keras/src/ops/nn.py) rather than in a backend-specific way.
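For orientation, a test there calls the public op rather than a backend function. A minimal illustrative sketch (not the actual test from this PR, and assuming `dot_product_attention` follows the `jax.nn` shape convention of `(batch, seq, heads, head_dim)`):

```python
import numpy as np
from keras import ops


def test_dot_product_attention_smoke():
    # (batch, seq_len, num_heads, head_dim), self-attention on random data.
    query = np.random.random((1, 8, 2, 16)).astype("float32")
    out = ops.dot_product_attention(query, query, query)
    # The output has the same shape as the query.
    assert tuple(out.shape) == query.shape
```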
@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?
It should be possible to consolidate this into dot_product_attention. That's how it's implemented in torch, and I've seen a similar approach in jax
(https://github.com/jax-ml/jax/blob/81a31f6adf453b2afc39936e15c15d8ad327bf6e/jax/_src/nn/functions.py#L1037-L1041).
As far as I know, torch uses flash attention automatically when the conditions are met, while for jax we need to pass implementation="cudnn" to use it.
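For reference, a minimal sketch of the jax call mentioned above, assuming a recent JAX release where `jax.nn.dot_product_attention` exists and a GPU with the cuDNN flash kernels available:

```python
import jax
import jax.numpy as jnp

# jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim).
rng = jax.random.key(0)
query = jax.random.normal(rng, (1, 128, 4, 64), dtype=jnp.float16)

# implementation=None (the default) lets JAX choose a backend;
# implementation="cudnn" explicitly requests the fused flash attention kernel
# and errors out if it cannot be used on the current device/dtype.
out = jax.nn.dot_product_attention(query, query, query, implementation="cudnn")
```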
Very cool -- @hazemessamm can we do that, e.g. by adding a flash_attention argument to dot_product_attention? That would also make it quite easy to add support for JAX (in addition to PyTorch). For TF I think we can skip support for now.
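For illustration, the user-facing call with the proposed argument could look like this (a hypothetical sketch; the argument name and behavior are the proposal above, not a released API):

```python
import keras
from keras import ops

# float16 inputs, since the flash kernels are dtype-restricted.
query = keras.random.normal((1, 128, 4, 64), dtype="float16")

# Proposed: flash_attention=True asks the backend for its fused/flash kernel
# (torch SDPA, JAX cuDNN); omitting it keeps the current default behavior.
out = ops.dot_product_attention(query, query, query, flash_attention=True)
```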
Awesome work! Thank you.
Thank you, glad I could help.
The test fails on torch + GPU:
FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution.
Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on?
I think flash attention in PyTorch only works with float16 (and only on specific GPUs). I just tested it on an H100 GPU and it worked fine, but it did not work on a T4 GPU on Colab.
I also just found the following PyTorch functions, which we can use to check whether flash attention can be used for the given inputs on the current GPU:
```python
import torch

# Dummy self-attention inputs: (batch, num_heads, seq_len, head_dim).
bsz, num_heads, seqlen, head_dim = 1, 2, 10, 16
query = torch.randn((bsz, num_heads, seqlen, head_dim), dtype=torch.float32, device='cuda:0')

# Ask PyTorch whether the flash attention kernel can run for these inputs on the current GPU.
params = torch.backends.cuda.SDPAParams(query, query, query, None, 16**-0.5, False)
is_flash_attention_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
print(is_flash_attention_enabled)  # False; becomes True with `dtype=torch.float16`.
```
If you think this is a good idea, I will use this snippet in the flash attention function in the PyTorch backend.
Documentation: https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.SDPAParams and https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.can_use_flash_attention
That sounds great! Then, we can also skip the PyTorch unit test when this check evaluates to False.
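A rough sketch of what such a skip guard could look like in the test file (the helper name and exact constructor arguments are assumptions; the SDPAParams signature varies across PyTorch versions):

```python
import pytest
from keras.src import backend


def _torch_can_use_flash_attention(head_dim=16):
    # Returns True only if PyTorch reports the flash kernel is usable on this GPU.
    import torch

    if not torch.cuda.is_available():
        return False
    q = torch.randn((1, 2, 10, head_dim), dtype=torch.float16, device="cuda")
    params = torch.backends.cuda.SDPAParams(q, q, q, None, 0.0, False)
    return torch.backends.cuda.can_use_flash_attention(params, False)


@pytest.mark.skipif(
    backend.backend() == "torch" and not _torch_can_use_flash_attention(),
    reason="Flash attention is not supported on this torch GPU/dtype setup.",
)
def test_dot_product_attention_flash():
    ...
```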
I skipped the tests for TensorFlow, NumPy, and torch. I just tested JAX on a T4 GPU on Colab and got this error: `RuntimeError: Require at least Ampere arch to run`. So the JAX + GPU tests will need to run on an Ampere (or newer) architecture; otherwise we will need to skip the tests for all frameworks. Also, the JAX version currently used in the GitHub tests does not have the dot_product_attention function.
I added some conditions to skip the JAX tests when they are met (see the sketch below). What do you think?
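For context, the kind of conditions being checked for JAX might look roughly like this (a sketch with assumed attribute names, not the exact code from the PR):

```python
import jax


def _jax_supports_flash_attention():
    # Older JAX releases (e.g. the one in the GitHub CI) lack the fused op.
    if not hasattr(jax.nn, "dot_product_attention"):
        return False
    device = jax.devices()[0]
    if device.platform != "gpu":
        return False
    # The cuDNN flash kernels need Ampere or newer (compute capability >= 8.0).
    compute_capability = getattr(device, "compute_capability", "0")
    return float(compute_capability) >= 8.0
```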