onnxruntime

Implement FlashAttention for CPU

Open • duanqn opened this pull request 1 year ago • 35 comments

Description

Implement FlashAttention and FlashAttention-2 for MultiHeadAttention on CPU.
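For readers new to the technique, the core idea FlashAttention relies on can be sketched in NumPy. K/V are processed in tiles, and per-tile results are merged with an online softmax (running row max and running denominator), with the final normalization deferred to the end as in FlashAttention-2. This is an illustrative sketch, not the C++ kernel added by this PR; the function names and the `block_q`/`block_k` tile sizes are hypothetical.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: softmax(Q K^T / sqrt(d)) V with the full score matrix materialized."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = q @ k.T * scale                      # (Sq, Sk)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v


def blocked_attention(q, k, v, block_q=32, block_k=32):
    """FlashAttention-style tiling: only a (block_q, block_k) tile of scores exists at a time.

    For each query block, K/V blocks are streamed while an online softmax keeps a running
    row max `m`, a running denominator `l`, and an unnormalized output accumulator `acc`.
    """
    sq, d = q.shape
    sk = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    out = np.empty((sq, v.shape[1]))
    for i in range(0, sq, block_q):
        qi = q[i:i + block_q]
        m = np.full(qi.shape[0], -np.inf)
        l = np.zeros(qi.shape[0])
        acc = np.zeros((qi.shape[0], v.shape[1]))
        for j in range(0, sk, block_k):
            s = qi @ k[j:j + block_k].T * scale   # scores for this tile
            m_new = np.maximum(m, s.max(axis=-1))
            p = np.exp(s - m_new[:, None])        # unnormalized tile probabilities
            correction = np.exp(m - m_new)        # rescales earlier partial sums (0 on first tile)
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[j:j + block_k]
            m = m_new
        out[i:i + block_q] = acc / l[:, None]     # deferred normalization
    return out
```

Scratch memory for scores no longer grows with `Sq × Sk`, and the sketch also handles `Sq != Sk`, as in cross-attention.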

Motivation and Context

Accelerate the execution of MultiHeadAttention.

Current performance: 10 ms with FlashAttention vs. 16 ms with the existing com.microsoft.MultiHeadAttention kernel on my Linux machine, and 10 ms vs. 38 ms on my Windows machine. Further optimization may be needed.

duanqn, May 24 '24

Test failing: MultiHeadAttentionTest.CrossAttention_DiffSequenceLengths

Edit: passed

duanqn, Jun 19 '24

Environment Variables: ORT_DISABLE_FLASH_ATTENTION=0

format  causal  batch  seqlen  heads  h_dim  ms       TFLOPS  kernel
Q,K,V   False   1      128     32     128    1.59     0.17    CPU:Flash
Q,K,V   False   1      256     32     128    2.74     0.39    CPU:Flash
Q,K,V   False   1      512     32     128    8.28     0.52    CPU:Flash
Q,K,V   False   1      1024    32     128    26.43    0.65    CPU:Flash
Q,K,V   False   1      2048    32     128    88.92    0.77    CPU:Flash
Q,K,V   False   1      4096    8      40     36.26    0.59    CPU:Flash
Q,K,V   False   1      4096    8      80     54.36    0.79    CPU:Flash
Q,K,V   False   1      4096    8      160    99.28    0.87    CPU:Flash
Q,K,V   False   4      4096    8      40     144.85   0.59    CPU:Flash
Q,K,V   False   4      4096    8      80     217.08   0.79    CPU:Flash
Q,K,V   False   4      4096    8      160    400.06   0.86    CPU:Flash
Q,K,V   False   1      16384   8      40     570.16   0.60    CPU:Flash
Q,K,V   False   1      16384   8      80     854.11   0.80    CPU:Flash
Q,K,V   False   1      16384   8      160    1511.06  0.91    CPU:Flash
Q,K,V   False   128    128     12     64     29.84    0.22    CPU:Flash
Q,K,V   False   64     128     12     64     14.82    0.22    CPU:Flash
Q,K,V   False   128    384     12     64     131.07   0.44    CPU:Flash
Q,K,V   False   64     384     12     64     65.70    0.44    CPU:Flash
Q,K,V   False   128    512     12     64     203.86   0.51    CPU:Flash
Q,K,V   False   64     512     12     64     99.83    0.52    CPU:Flash
Q,K,V   False   4      2048    32     128    350.01   0.79    CPU:Flash
Q,K,V   False   4      4096    32     128    1278.42  0.86    CPU:Flash
Q,K,V   False   8      2048    32     128    698.98   0.79    CPU:Flash
Q,K,V   False   8      4096    32     128    2547.00  0.86    CPU:Flash
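The TFLOPS column appears to count 4 · batch · heads · seqlen² · h_dim floating-point operations (two matmuls per head, QKᵀ and PV, at 2 · seqlen² · h_dim each) divided by the latency. This formula is inferred, not stated in the thread, but the hypothetical helper below reproduces the reported values for several spot-checked rows.

```python
def attention_tflops(batch, seqlen, heads, head_dim, ms):
    """Non-causal attention throughput in TFLOPS.

    FLOPs = 2 matmuls (Q K^T and P V) * 2 * seqlen^2 * head_dim, per head and batch item.
    """
    flops = 4 * batch * heads * seqlen * seqlen * head_dim
    return flops / (ms * 1e-3) / 1e12


# Row "batch=1, seqlen=16384, heads=8, h_dim=160, 1511.06 ms" from the Flash results.
print(f"{attention_tflops(1, 16384, 8, 160, 1511.06):.2f}")
```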

Environment Variables: ORT_DISABLE_FLASH_ATTENTION=1

format  causal  batch  seqlen  heads  h_dim  ms       TFLOPS  kernel
Q,K,V   False   1      128     32     128    1.43     0.19    CPU:Unfused
Q,K,V   False   1      256     32     128    3.24     0.33    CPU:Unfused
Q,K,V   False   1      512     32     128    11.26    0.38    CPU:Unfused
Q,K,V   False   1      1024    32     128    36.88    0.47    CPU:Unfused
Q,K,V   False   1      2048    32     128    106.25   0.65    CPU:Unfused
Q,K,V   False   1      4096    8      40     49.43    0.43    CPU:Unfused
Q,K,V   False   1      4096    8      80     75.99    0.57    CPU:Unfused
Q,K,V   False   1      4096    8      160    137.47   0.62    CPU:Unfused
Q,K,V   False   4      4096    8      40     194.25   0.44    CPU:Unfused
Q,K,V   False   4      4096    8      80     298.62   0.58    CPU:Unfused
Q,K,V   False   4      4096    8      160    540.00   0.64    CPU:Unfused
Q,K,V   False   1      16384   8      40     962.66   0.36    CPU:Unfused
Q,K,V   False   1      16384   8      80     1389.89  0.49    CPU:Unfused
Q,K,V   False   1      16384   8      160    2605.56  0.53    CPU:Unfused
Q,K,V   False   128    128     12     64     33.08    0.19    CPU:Unfused
Q,K,V   False   64     128     12     64     16.26    0.20    CPU:Unfused
Q,K,V   False   128    384     12     64     149.92   0.39    CPU:Unfused
Q,K,V   False   64     384     12     64     75.20    0.39    CPU:Unfused
Q,K,V   False   128    512     12     64     234.68   0.44    CPU:Unfused
Q,K,V   False   64     512     12     64     117.20   0.44    CPU:Unfused
Q,K,V   False   4      2048    32     128    409.42   0.67    CPU:Unfused
Q,K,V   False   4      4096    32     128    1561.20  0.70    CPU:Unfused
Q,K,V   False   8      2048    32     128    814.60   0.67    CPU:Unfused
Q,K,V   False   8      4096    32     128    3112.91  0.71    CPU:Unfused
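The two result sets are easier to compare as a ratio of unfused latency (ORT_DISABLE_FLASH_ATTENTION=1) to FlashAttention latency (ORT_DISABLE_FLASH_ATTENTION=0). The sketch below computes it for a handful of configurations copied from the results above; the dictionary layout and the `speedups` name are illustrative.

```python
# (batch, seqlen, heads, h_dim) -> latency in ms, copied from the results above
flash_ms = {
    (1, 128, 32, 128): 1.59,
    (1, 2048, 32, 128): 88.92,
    (1, 16384, 8, 160): 1511.06,
    (64, 128, 12, 64): 14.82,
    (8, 4096, 32, 128): 2547.00,
}
unfused_ms = {
    (1, 128, 32, 128): 1.43,
    (1, 2048, 32, 128): 106.25,
    (1, 16384, 8, 160): 2605.56,
    (64, 128, 12, 64): 16.26,
    (8, 4096, 32, 128): 3112.91,
}


def speedups(flash, unfused):
    """Return {config: unfused_ms / flash_ms}; values above 1 mean FlashAttention is faster."""
    return {cfg: unfused[cfg] / flash[cfg] for cfg in flash if cfg in unfused}


for cfg, ratio in speedups(flash_ms, unfused_ms).items():
    print(cfg, f"{ratio:.2f}x")
```

On these rows the ratio is about 0.90x for the smallest case (batch 1, seqlen 128), 1.10x to 1.22x for the mid-sized cases, and 1.72x for batch 1, seqlen 16384, h_dim 160, so the fused kernel's advantage grows with sequence length.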

duanqn, Jun 21 '24

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu, Jun 21 '24

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu, Jun 21 '24

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

tianleiwu, Jun 21 '24

Azure Pipelines successfully started running 3 pipeline(s).

azure-pipelines[bot], Jun 21 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 21 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 21 '24

Please fix the Python formatting by running lintrunner at the repository root:

pip install -r requirements-lintrunner.txt
pip install lintrunner
lintrunner init
lintrunner -a

tianleiwu, Jun 21 '24

Enabled FlashAttention by default. Tested on an Azure D16ds_v5 VM.

Environment Variables: (none set)

format  causal  batch  seqlen  heads  h_dim  ms       TFLOPS  kernel
Q,K,V   False   1      128     32     128    1.49     0.18    CPU:Flash
Q,K,V   False   1      256     32     128    2.63     0.41    CPU:Flash
Q,K,V   False   1      512     32     128    8.38     0.51    CPU:Flash
Q,K,V   False   1      1024    32     128    26.84    0.64    CPU:Flash
Q,K,V   False   1      2048    32     128    89.41    0.77    CPU:Flash

Environment Variables: ORT_DISABLE_FLASH_ATTENTION=1

format  causal  batch  seqlen  heads  h_dim  ms       TFLOPS  kernel
Q,K,V   False   1      128     32     128    1.50     0.18    CPU:Unfused
Q,K,V   False   1      256     32     128    3.30     0.33    CPU:Unfused
Q,K,V   False   1      512     32     128    11.04    0.39    CPU:Unfused
Q,K,V   False   1      1024    32     128    37.28    0.46    CPU:Unfused
Q,K,V   False   1      2048    32     128    135.93   0.51    CPU:Unfused

Environment Variables: ORT_DISABLE_FLASH_ATTENTION=0

format  causal  batch  seqlen  heads  h_dim  ms       TFLOPS  kernel
Q,K,V   False   1      128     32     128    1.67     0.16    CPU:Flash
Q,K,V   False   1      256     32     128    2.79     0.38    CPU:Flash
Q,K,V   False   1      512     32     128    8.33     0.52    CPU:Flash
Q,K,V   False   1      1024    32     128    26.61    0.65    CPU:Flash
Q,K,V   False   1      2048    32     128    88.45    0.78    CPU:Flash
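Across the three runs above, the Flash kernel ran both with no variable set and with ORT_DISABLE_FLASH_ATTENTION=0, while the unfused kernel ran with =1. For A/B benchmarking from Python, a small context manager can toggle the variable and restore it afterwards. It assumes, without confirmation in this thread, that the value is read when the session and its kernels are created, so the `onnxruntime.InferenceSession` should be constructed inside the block. The helper name and the model path in the comment are hypothetical.

```python
import os
from contextlib import contextmanager

_ENV_KEY = "ORT_DISABLE_FLASH_ATTENTION"


@contextmanager
def flash_attention_disabled(disabled=True):
    """Temporarily set ORT_DISABLE_FLASH_ATTENTION ("1" selects the unfused kernel,
    "0" keeps FlashAttention) and restore the previous value on exit."""
    previous = os.environ.get(_ENV_KEY)
    os.environ[_ENV_KEY] = "1" if disabled else "0"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(_ENV_KEY, None)
        else:
            os.environ[_ENV_KEY] = previous


# Example (hypothetical model path; assumes the variable is read at session creation):
# import onnxruntime as ort
# with flash_attention_disabled(True):
#     sess = ort.InferenceSession("mha.onnx", providers=["CPUExecutionProvider"])
```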

duanqn, Jun 24 '24

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu, Jun 24 '24

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu, Jun 24 '24

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

tianleiwu, Jun 24 '24

Azure Pipelines successfully started running 3 pipeline(s).

azure-pipelines[bot], Jun 24 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 24 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 24 '24

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu, Jun 25 '24

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu, Jun 25 '24

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

tianleiwu, Jun 25 '24

Azure Pipelines successfully started running 3 pipeline(s).

azure-pipelines[bot], Jun 25 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 25 '24

Azure Pipelines successfully started running 10 pipeline(s).

azure-pipelines[bot], Jun 25 '24

Please run lintrunner; the Python format pipeline failed.

tianleiwu, Jun 27 '24

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu, Jun 28 '24

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu, Jun 28 '24

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

tianleiwu, Jun 28 '24

Pipelines were unable to run due to time out waiting for the pull request to finish merging.

azure-pipelines[bot], Jun 28 '24

Pipelines were unable to run due to time out waiting for the pull request to finish merging.

azure-pipelines[bot], Jun 28 '24

Pipelines were unable to run due to time out waiting for the pull request to finish merging.

azure-pipelines[bot], Jun 28 '24

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu, Jun 28 '24