Optimize SDPA implementation

Open EricLBuehler opened this issue 1 year ago • 0 comments

Reasoning:

  1. We use lots of elementwise operations: masked_fill in every layer, elementwise addition and division in our attention implementations.
  2. GEMM APIs such as cuBLAS's gemm take alpha and beta scaling parameters (which lets us fuse the elementwise division!) and can accumulate into a preallocated tensor (which fuses the elementwise addition and saves an allocation).
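
To make the second point concrete, here is a minimal sketch of the general GEMM contract, C := alpha * AxB + beta * C, written as plain Rust loops over row-major slices. The function name `gemm` and its argument order are illustrative assumptions, not cuBLAS's or Candle's actual signatures; the point is only that the scaling and the accumulation happen in the same pass as the matrix product.

```rust
/// Reference GEMM sketch (hypothetical helper, not a real library API):
/// c := alpha * (a x b) + beta * c
/// `a` is m x k, `b` is k x n, `c` is m x n, all row-major.
fn gemm(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &[f32],
    b: &[f32],
    beta: f32,
    c: &mut [f32],
) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of `a` with column j of `b`.
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // Scaling (alpha) and accumulation (beta) are applied while
            // writing the output, so no extra elementwise pass is needed.
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}
```

With beta = 1 the existing contents of `c` are added to the scaled product, and with beta = 0 they are overwritten, which is why a single call can stand in for a matmul followed by a division and an addition.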

I thought we could apply this to Candle's matmul operator, so I added two new methods: matmul_with_alpha_beta, which computes C := C + alpha * AxB and is used to apply the attention mask, and matmul_with_alpha, which computes C := alpha * AxB. I also added a new function to candle_nn called scaled_dot_product_attention, which wraps this in a clean interface and can call FA, too.

I tested this in a few models: it works, and the results look promising. Here are the benchmarks I ran:
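
As a rough illustration of the idea (not the code from this change), the sketch below computes attention scores for a single head in two ways: a naive version with a separate division pass and a separate mask-addition pass, and a fused version that preloads the output with an additive mask and accumulates alpha * QxK^T into it. The names `scores_naive`, `scores_fused`, and `softmax_rows` are hypothetical, and the additive mask (0 for visible positions, negative infinity for masked ones) is an assumption about how the mask is represented.

```rust
/// Naive scores: Q x K^T, then an elementwise division, then an elementwise add.
/// `q` and `k` are seq x dim, `mask` is seq x seq, all row-major.
fn scores_naive(q: &[f32], k: &[f32], mask: &[f32], seq: usize, dim: usize) -> Vec<f32> {
    let mut s = vec![0.0f32; seq * seq];
    for i in 0..seq {
        for j in 0..seq {
            for p in 0..dim {
                s[i * seq + j] += q[i * dim + p] * k[j * dim + p];
            }
        }
    }
    let denom = (dim as f32).sqrt();
    // Extra pass and allocation: divide by sqrt(dim).
    let scaled: Vec<f32> = s.iter().map(|x| x / denom).collect();
    // Extra pass and allocation: add the mask.
    scaled.iter().zip(mask).map(|(x, m)| x + m).collect()
}

/// Fused scores: C := C + alpha * Q x K^T, with C preloaded with the mask.
fn scores_fused(q: &[f32], k: &[f32], mask: &[f32], seq: usize, dim: usize) -> Vec<f32> {
    let alpha = 1.0 / (dim as f32).sqrt();
    let mut c = mask.to_vec(); // the output buffer starts out as the mask
    for i in 0..seq {
        for j in 0..seq {
            let mut acc = 0.0f32;
            for p in 0..dim {
                acc += q[i * dim + p] * k[j * dim + p];
            }
            // Scale and accumulate in the same pass as the product.
            c[i * seq + j] += alpha * acc;
        }
    }
    c
}

/// Numerically stable softmax over each row of a row-major matrix.
fn softmax_rows(x: &mut [f32], cols: usize) {
    for row in x.chunks_mut(cols) {
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}
```

Both versions produce the same scores up to floating-point rounding; the fused one simply avoids the two intermediate tensors. In a real backend the inner loops would be a single GEMM call with the appropriate alpha and beta rather than hand-written Rust.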

cargo bench --features ... --package candle-nn
cuda_attention_fast/iter
                        time:   [20.733 ms 20.806 ms 20.888 ms]
                        thrpt:  [765.98 MiB/s 769.00 MiB/s 771.70 MiB/s]

cuda_attention_naive/iter
                        time:   [45.015 ms 45.189 ms 45.388 ms]
                        thrpt:  [352.52 MiB/s 354.07 MiB/s 355.43 MiB/s]

cpu_attention_fast/iter time:   [692.47 ms 708.14 ms 723.66 ms]
                        thrpt:  [22.110 MiB/s 22.594 MiB/s 23.106 MiB/s]

cpu_attention_naive/iter
                        time:   [913.23 ms 935.03 ms 957.51 ms]
                        thrpt:  [16.710 MiB/s 17.112 MiB/s 17.520 MiB/s]

mkl_attention_fast/iter time:   [326.88 ms 329.80 ms 332.94 ms]
                        thrpt:  [48.057 MiB/s 48.514 MiB/s 48.947 MiB/s]

mkl_attention_naive/iter
                        time:   [530.49 ms 545.89 ms 563.74 ms]
                        thrpt:  [28.382 MiB/s 29.310 MiB/s 30.161 MiB/s]

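I have added support for the following backends (speedup of the fast path over the naive one in parentheses, where measured):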

  • CUDA (2.17x)
  • Metal (not tested yet)
  • MKL (1.64x)
  • Accelerate (not tested yet)
  • Plain CPU GEMM (1.4x)

EricLBuehler, Aug 05 '24 14:08