onnxruntime
Implement FlashAttention for CPU
Description
Implement FlashAttention and FlashAttention-2 for MultiHeadAttention on CPU.
Motivation and Context
Accelerate the execution of MultiHeadAttention.
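For readers unfamiliar with the technique, here is a minimal pure-Python sketch of the idea behind FlashAttention: K and V are visited in blocks while a running maximum and running sum keep the softmax numerically stable, so the full score matrix is never materialized. This is an illustration only, not the C++ kernel added in this PR; the function names and `block_size` parameter are hypothetical, and the real kernel also tiles Q and runs in parallel.

```python
import math


def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v for one head, on lists of rows."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_size=2):
    """Same result, but K/V are consumed block by block (online softmax)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        running_max = -math.inf   # largest score seen so far
        running_sum = 0.0         # softmax denominator, relative to running_max
        acc = [0.0] * dv          # unnormalized output, relative to running_max
        for start in range(0, len(k), block_size):
            k_blk = k[start:start + block_size]
            v_blk = v[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale what was accumulated under the old maximum.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


if __name__ == "__main__":
    q = [[0.1, 0.2], [0.3, -0.4]]
    k = [[0.5, 0.1], [-0.2, 0.7], [0.9, -0.3]]
    v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    print(tiled_attention(q, k, v))
    print(naive_attention(q, k, v))
```

Both functions should agree up to floating-point rounding; the tiled version only ever holds one block of scores at a time.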
Current performance: 10 ms vs. 16 ms (com.microsoft.MultiHeadAttention) on my Linux machine, and 10 ms vs. 38 ms (com.microsoft.MultiHeadAttention) on my Windows machine. Further optimization may be needed.
Test failing: MultiHeadAttentionTest.CrossAttention_DiffSequenceLengths
Edit: the test now passes.
Environment Variables: ORT_DISABLE_FLASH_ATTENTION=0
format  causal  batch  seqlen  heads  h_dim       ms  TFLOPS  kernel
Q,K,V   False       1     128     32    128     1.59    0.17  CPU:Flash
Q,K,V   False       1     256     32    128     2.74    0.39  CPU:Flash
Q,K,V   False       1     512     32    128     8.28    0.52  CPU:Flash
Q,K,V   False       1    1024     32    128    26.43    0.65  CPU:Flash
Q,K,V   False       1    2048     32    128    88.92    0.77  CPU:Flash
Q,K,V   False       1    4096      8     40    36.26    0.59  CPU:Flash
Q,K,V   False       1    4096      8     80    54.36    0.79  CPU:Flash
Q,K,V   False       1    4096      8    160    99.28    0.87  CPU:Flash
Q,K,V   False       4    4096      8     40   144.85    0.59  CPU:Flash
Q,K,V   False       4    4096      8     80   217.08    0.79  CPU:Flash
Q,K,V   False       4    4096      8    160   400.06    0.86  CPU:Flash
Q,K,V   False       1   16384      8     40   570.16    0.60  CPU:Flash
Q,K,V   False       1   16384      8     80   854.11    0.80  CPU:Flash
Q,K,V   False       1   16384      8    160  1511.06    0.91  CPU:Flash
Q,K,V   False     128     128     12     64    29.84    0.22  CPU:Flash
Q,K,V   False      64     128     12     64    14.82    0.22  CPU:Flash
Q,K,V   False     128     384     12     64   131.07    0.44  CPU:Flash
Q,K,V   False      64     384     12     64    65.70    0.44  CPU:Flash
Q,K,V   False     128     512     12     64   203.86    0.51  CPU:Flash
Q,K,V   False      64     512     12     64    99.83    0.52  CPU:Flash
Q,K,V   False       4    2048     32    128   350.01    0.79  CPU:Flash
Q,K,V   False       4    4096     32    128  1278.42    0.86  CPU:Flash
Q,K,V   False       8    2048     32    128   698.98    0.79  CPU:Flash
Q,K,V   False       8    4096     32    128  2547.00    0.86  CPU:Flash
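As a reading aid for the TFLOPS column, the sketch below recomputes it from the other columns. The formula is an assumption on my part (it is not stated in this thread): non-causal attention is counted as two matrix multiplications, QK^T and PV, at 2 FLOPs per multiply-add, giving 4 * batch * seqlen^2 * heads * h_dim. The helper name is hypothetical; the rows checked below are consistent with this count.

```python
def attention_tflops(batch, seqlen, heads, h_dim, ms):
    """Estimate TFLOPS for non-causal attention (assumed FLOP count)."""
    # Two matmuls (Q @ K^T and P @ V), 2 FLOPs per multiply-add.
    flops = 4 * batch * seqlen * seqlen * heads * h_dim
    seconds = ms * 1e-3
    return flops / seconds / 1e12


if __name__ == "__main__":
    # Rows taken from the CPU:Flash table above.
    print(round(attention_tflops(1, 128, 32, 128, 1.59), 2))       # 0.17
    print(round(attention_tflops(1, 2048, 32, 128, 88.92), 2))     # 0.77
    print(round(attention_tflops(1, 16384, 8, 160, 1511.06), 2))   # 0.91
```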
Environment Variables: ORT_DISABLE_FLASH_ATTENTION=1
format  causal  batch  seqlen  heads  h_dim       ms  TFLOPS  kernel
Q,K,V   False       1     128     32    128     1.43    0.19  CPU:Unfused
Q,K,V   False       1     256     32    128     3.24    0.33  CPU:Unfused
Q,K,V   False       1     512     32    128    11.26    0.38  CPU:Unfused
Q,K,V   False       1    1024     32    128    36.88    0.47  CPU:Unfused
Q,K,V   False       1    2048     32    128   106.25    0.65  CPU:Unfused
Q,K,V   False       1    4096      8     40    49.43    0.43  CPU:Unfused
Q,K,V   False       1    4096      8     80    75.99    0.57  CPU:Unfused
Q,K,V   False       1    4096      8    160   137.47    0.62  CPU:Unfused
Q,K,V   False       4    4096      8     40   194.25    0.44  CPU:Unfused
Q,K,V   False       4    4096      8     80   298.62    0.58  CPU:Unfused
Q,K,V   False       4    4096      8    160   540.00    0.64  CPU:Unfused
Q,K,V   False       1   16384      8     40   962.66    0.36  CPU:Unfused
Q,K,V   False       1   16384      8     80  1389.89    0.49  CPU:Unfused
Q,K,V   False       1   16384      8    160  2605.56    0.53  CPU:Unfused
Q,K,V   False     128     128     12     64    33.08    0.19  CPU:Unfused
Q,K,V   False      64     128     12     64    16.26    0.20  CPU:Unfused
Q,K,V   False     128     384     12     64   149.92    0.39  CPU:Unfused
Q,K,V   False      64     384     12     64    75.20    0.39  CPU:Unfused
Q,K,V   False     128     512     12     64   234.68    0.44  CPU:Unfused
Q,K,V   False      64     512     12     64   117.20    0.44  CPU:Unfused
Q,K,V   False       4    2048     32    128   409.42    0.67  CPU:Unfused
Q,K,V   False       4    4096     32    128  1561.20    0.70  CPU:Unfused
Q,K,V   False       8    2048     32    128   814.60    0.67  CPU:Unfused
Q,K,V   False       8    4096     32    128  3112.91    0.71  CPU:Unfused
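To make the two runs easier to compare, here is a small sketch that computes the Flash-over-Unfused speedup for a few configurations. The latencies are copied from the ms columns of the two tables above; the dictionary layout and function name are hypothetical and only a handful of rows are included.

```python
# (batch, seqlen, heads, h_dim) -> (flash_ms, unfused_ms), copied from the tables above.
LATENCY_MS = {
    (1, 128, 32, 128): (1.59, 1.43),
    (1, 2048, 32, 128): (88.92, 106.25),
    (1, 16384, 8, 160): (1511.06, 2605.56),
    (8, 4096, 32, 128): (2547.00, 3112.91),
}


def speedup(config):
    """Return unfused_ms / flash_ms; values above 1.0 mean Flash is faster."""
    flash_ms, unfused_ms = LATENCY_MS[config]
    return unfused_ms / flash_ms


if __name__ == "__main__":
    for config in LATENCY_MS:
        print(config, f"{speedup(config):.2f}x")
```

On these rows the gain grows with sequence length, and at seqlen 128 the unfused kernel is still slightly faster.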
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline
/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline
/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline
Azure Pipelines successfully started running 3 pipeline(s).
Azure Pipelines successfully started running 10 pipeline(s).
Azure Pipelines successfully started running 10 pipeline(s).
Please fix the Python formatting by running lintrunner from the repository root, for example:
pip install -r requirements-lintrunner.txt
pip install lintrunner
lintrunner init
lintrunner -a
Enabled FlashAttention by default, tested on Azure D16ds_v5
Environment Variables: (none set)
format  causal  batch  seqlen  heads  h_dim       ms  TFLOPS  kernel
Q,K,V   False       1     128     32    128     1.49    0.18  CPU:Flash
Q,K,V   False       1     256     32    128     2.63    0.41  CPU:Flash
Q,K,V   False       1     512     32    128     8.38    0.51  CPU:Flash
Q,K,V   False       1    1024     32    128    26.84    0.64  CPU:Flash
Q,K,V   False       1    2048     32    128    89.41    0.77  CPU:Flash
Environment Variables: ORT_DISABLE_FLASH_ATTENTION=1
format  causal  batch  seqlen  heads  h_dim       ms  TFLOPS  kernel
Q,K,V   False       1     128     32    128     1.50    0.18  CPU:Unfused
Q,K,V   False       1     256     32    128     3.30    0.33  CPU:Unfused
Q,K,V   False       1     512     32    128    11.04    0.39  CPU:Unfused
Q,K,V   False       1    1024     32    128    37.28    0.46  CPU:Unfused
Q,K,V   False       1    2048     32    128   135.93    0.51  CPU:Unfused
Environment Variables: ORT_DISABLE_FLASH_ATTENTION=0
format  causal  batch  seqlen  heads  h_dim       ms  TFLOPS  kernel
Q,K,V   False       1     128     32    128     1.67    0.16  CPU:Flash
Q,K,V   False       1     256     32    128     2.79    0.38  CPU:Flash
Q,K,V   False       1     512     32    128     8.33    0.52  CPU:Flash
Q,K,V   False       1    1024     32    128    26.61    0.65  CPU:Flash
Q,K,V   False       1    2048     32    128    88.45    0.78  CPU:Flash
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline
/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline
/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline
Azure Pipelines successfully started running 3 pipeline(s).
Azure Pipelines successfully started running 10 pipeline(s).
Azure Pipelines successfully started running 10 pipeline(s).
You need to run lintrunner, since the Python format pipeline failed.
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline
/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline
/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline
Pipelines were unable to run due to time out waiting for the pull request to finish merging.
Pipelines were unable to run due to time out waiting for the pull request to finish merging.
Pipelines were unable to run due to time out waiting for the pull request to finish merging.
/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline