Optimize SDPA implementation
Reasoning:
- Our attention implementations use many elementwise operations: a masked_fill in every layer, plus an elementwise addition (mask application) and division (scaling).
- GEMM APIs such as cuBLAS's gemm expose alpha and beta parameters (which let us fuse the elementwise division) and allow accumulating into a preallocated tensor (which fuses the elementwise addition and saves an allocation).
I thought we could apply this to Candle's matmul operator, so I added two new methods: matmul_with_alpha_beta, which computes C := C + alpha * (A x B) and is used to apply the attention mask, and matmul_with_alpha, which computes C := alpha * (A x B). I also added a new function to candle_nn, scaled_dot_product_attention, which wraps this in a clean interface and can also dispatch to flash attention.
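The fusion idea can be sketched with a small reference GEMM that mimics cuBLAS's alpha/beta semantics. This is an illustrative stand-in, not Candle's actual kernel or API (gemm_alpha_beta and the numbers here are made up for the example):

```rust
// Reference GEMM with cuBLAS-style semantics: C := alpha * (A x B) + beta * C.
// With C pre-filled with the attention mask and beta = 1, the mask addition
// is fused into the matmul; with alpha = 1/sqrt(head_dim), the scaling
// division is fused too. Row-major layout, illustrative only.
fn gemm_alpha_beta(
    a: &[f32], b: &[f32], c: &mut [f32],
    m: usize, k: usize, n: usize,
    alpha: f32, beta: f32,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}

fn main() {
    // A = [[1, 2], [3, 4]], B = identity, C pre-filled with a causal-style
    // mask (-1000.0 stands in for the masked_fill value).
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![1.0, 0.0, 0.0, 1.0];
    let mut c = vec![0.0, -1000.0, 0.0, 0.0];
    let scale = 0.5; // stands in for 1/sqrt(head_dim)
    gemm_alpha_beta(&a, &b, &mut c, 2, 2, 2, scale, 1.0);
    // A single pass produced scale * (A x B) + mask: no separate
    // division, addition, or extra output allocation.
    println!("{c:?}");
}
```

The same single-pass shape is what the alpha/beta parameters of a real GEMM backend buy us over running matmul, division, and mask addition as three separate kernels.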
I tested this in a few models: it works, and the results look promising. I ran benchmarks with:
cargo bench --features ... --package candle-nn
cuda_attention_fast/iter
        time:   [20.733 ms 20.806 ms 20.888 ms]
        thrpt:  [765.98 MiB/s 769.00 MiB/s 771.70 MiB/s]
cuda_attention_naive/iter
        time:   [45.015 ms 45.189 ms 45.388 ms]
        thrpt:  [352.52 MiB/s 354.07 MiB/s 355.43 MiB/s]
cpu_attention_fast/iter
        time:   [692.47 ms 708.14 ms 723.66 ms]
        thrpt:  [22.110 MiB/s 22.594 MiB/s 23.106 MiB/s]
cpu_attention_naive/iter
        time:   [913.23 ms 935.03 ms 957.51 ms]
        thrpt:  [16.710 MiB/s 17.112 MiB/s 17.520 MiB/s]
mkl_attention_fast/iter
        time:   [326.88 ms 329.80 ms 332.94 ms]
        thrpt:  [48.057 MiB/s 48.514 MiB/s 48.947 MiB/s]
mkl_attention_naive/iter
        time:   [530.49 ms 545.89 ms 563.74 ms]
        thrpt:  [28.382 MiB/s 29.310 MiB/s 30.161 MiB/s]
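For reference, the per-backend speedups can be recomputed from the median Criterion times above; a quick sanity check, with figures shifting slightly depending on which bound of the interval you take:

```rust
// Ratio of naive to fast median times from the benchmark output above.
fn main() {
    let pairs = [
        ("cuda", 45.189_f64, 20.806_f64),
        ("mkl", 545.89, 329.80),
        ("cpu", 935.03, 708.14),
    ];
    for (name, naive_ms, fast_ms) in pairs {
        println!("{name}: {:.2}x faster", naive_ms / fast_ms);
    }
}
```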
I have added support for:
- CUDA (2.17x)
- Metal (not tested yet)
- MKL (1.64x)
- Accelerate (not tested yet)
- Plain CPU GEMM (1.4x)