llm.c
feat(attention_forward.cu): Gentle introduction to CuTe (Cutlass)
This is a very, very gentle introduction to Flash Attention 2 with CuTe (Cutlass v3). It's gentle because it's not finished.
What I've got so far:
- Work partitioned across query block, batch, and head (as in Flash Attention 2, to the best of my knowledge);
- Efficiently copying tiles of Q and K using CuTe;
- Using CuTe primitives for the matrix multiply (gemm) and the scaling step (axpby); a sketch of these pieces follows this list.
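
For anyone who wants to see these pieces in one place, here is a minimal, hedged sketch (not the kernel in this PR): one CTA owns one (query block, batch*head) pair, stages Q and K tiles in shared memory with cute::copy, and computes the scaled scores with cute::gemm and cute::axpby. The kernel name, the 16x16 thread layout, and the BM/BN/HD tile parameters are illustrative assumptions; it assumes 256 threads per block and T, BM, BN, HD all divisible by 16.

```cuda
#include <cute/tensor.hpp>
using namespace cute;

// Sketch only: qk_scores_sketch, BM/BN/HD, and the grid mapping are assumptions,
// not this PR's kernel. Launch with 256 threads per block, grid = (T/BM, B*NH).
template <int BM, int BN, int HD>
__global__ void qk_scores_sketch(const float* __restrict__ Q,
                                 const float* __restrict__ K,
                                 float* __restrict__ S,
                                 int T, float scale) {
    int qblk = blockIdx.x;   // which block of query rows
    int bh   = blockIdx.y;   // fused batch*head index

    // Global views of this head's Q (T,HD), K (T,HD), and S (T,T), row-major.
    Tensor mQ = make_tensor(make_gmem_ptr(Q + (size_t)bh * T * HD),
                            make_shape(T, Int<HD>{}), make_stride(Int<HD>{}, Int<1>{}));
    Tensor mK = make_tensor(make_gmem_ptr(K + (size_t)bh * T * HD),
                            make_shape(T, Int<HD>{}), make_stride(Int<HD>{}, Int<1>{}));
    Tensor mS = make_tensor(make_gmem_ptr(S + (size_t)bh * T * T),
                            make_shape(T, T), make_stride(T, Int<1>{}));

    // Tiles owned by this CTA: one Q tile, and every key block of its score row.
    Tensor gQ = local_tile(mQ, Shape<Int<BM>, Int<HD>>{}, make_coord(qblk, 0));
    Tensor gS = local_tile(mS, Shape<Int<BM>, Int<BN>>{}, make_coord(qblk, _)); // (BM,BN,nKeyBlocks)

    // Static shared-memory layouts (compile-time shapes, see the layout note below).
    __shared__ float smemQ[BM * HD];
    __shared__ float smemK[BN * HD];
    Tensor sQ = make_tensor(make_smem_ptr(smemQ), Shape<Int<BM>, Int<HD>>{}, Stride<Int<HD>, Int<1>>{});
    Tensor sK = make_tensor(make_smem_ptr(smemK), Shape<Int<BN>, Int<HD>>{}, Stride<Int<HD>, Int<1>>{});

    // 16x16 thread layout, used for both the copies and the compute partition.
    auto thr = Layout<Shape<_16, _16>>{};
    Tensor tQgQ = local_partition(gQ, thr, threadIdx.x);
    Tensor tQsQ = local_partition(sQ, thr, threadIdx.x);
    copy(tQgQ, tQsQ);   // stage the Q tile once

    for (int kblk = 0; kblk < size<2>(gS); ++kblk) {
        Tensor gK   = local_tile(mK, Shape<Int<BN>, Int<HD>>{}, make_coord(kblk, 0));
        Tensor tKgK = local_partition(gK, thr, threadIdx.x);
        Tensor tKsK = local_partition(sK, thr, threadIdx.x);
        copy(tKgK, tKsK);
        __syncthreads();

        // Per-thread slices, CuTe convention: A = (M,K), B = (N,K), C = (M,N).
        Tensor tCsQ = local_partition(sQ, thr, threadIdx.x, Step<_1, Underscore>{}); // (BM/16, HD)
        Tensor tCsK = local_partition(sK, thr, threadIdx.x, Step<Underscore, _1>{}); // (BN/16, HD)
        Tensor gSk  = gS(_, _, kblk);                                                // (BM, BN)
        Tensor tCgS = local_partition(gSk, thr, threadIdx.x, Step<_1, _1>{});        // (BM/16, BN/16)

        Tensor tCrS = make_tensor_like(tCgS);  // register accumulator
        clear(tCrS);
        gemm(tCsQ, tCsK, tCrS);                // tCrS += Q_tile * K_tile^T
        axpby(scale, tCrS, 0.0f, tCgS);        // S_tile = scale * tCrS (beta = 0)
        __syncthreads();
    }
}
```

The online softmax and the second gemm against V, i.e. the parts flash attention still needs, are deliberately left out of this sketch.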
I'm converting this to a full PR because it may help people who want to start working with CuTe, and it may be less scary than jumping headfirst into a full flash attention implementation.
I will be hanging out on the CudaMode discord if anyone wants to pair or has a better understanding of CuTe and wants to help =).
Ok, I've figured out that I need to make the layouts of sQ and sK static for now if I want to use gemm(sQ, sK, ...).
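To illustrate what "static" means here (a hedged sketch with placeholder 64x64 tiles, not this PR's actual layouts): every extent in the shared-memory layouts is a compile-time Int<> constant, which is what lets gemm(sQ, sK, ...) resolve its shapes and loop bounds during compilation.

```cuda
#include <cstdio>
#include <cute/tensor.hpp>
using namespace cute;

int main() {
    // Dynamic layout: extents are runtime ints, so gemm(sQ, sK, ...) cannot check
    // or unroll over the shapes at compile time.
    int blockM = 64, headDim = 64;
    auto dyn_layout = make_layout(make_shape(blockM, headDim));

    // Static layouts: every shape/stride element is an Int<> constant
    // (64x64 tiles are placeholders, row-major strides).
    auto sQ_layout = make_layout(Shape<Int<64>, Int<64>>{}, Stride<Int<64>, Int<1>>{});
    auto sK_layout = make_layout(Shape<Int<64>, Int<64>>{}, Stride<Int<64>, Int<1>>{});

    print(dyn_layout); printf("   <- dynamic (runtime) extents\n");
    print(sQ_layout);  printf("   <- fully static, usable by gemm's compile-time checks\n");
    print(sK_layout);  printf("\n");
    return 0;
}
```

Inside the kernel these layouts get bound to shared memory with make_tensor(make_smem_ptr(smemQ), sQ_layout), as in the sketch earlier in the thread.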
Is it compiling now?
It is! Not at the performance I wish it had, but it definitely is compiling. Thanks for the Cutlass class @ericauld! Any tips on how to speed up this kernel?
@FeSens can you post what kind of perf you're seeing for this?
It's still far from cuBLAS. I'm working on getting the proper thread partitions before advancing to the other parts necessary for flash attention.
attention_query_key_kernel2 | Best time 54.376431 ms
cublasSgemmStridedBatched | Best time 4.497695 ms
Once this part reaches ~90% of the speed of cuBLAS, we will probably see further improvements once the missing parts are implemented.
This is at 65% of the speed of cuBLAS's cublasSgemmStridedBatched.
I believe you're running into the same trap I did. We currently don't enable tensor cores in this dev file, so cublasSgemmStridedBatched
is much slower here than it is going to be in the real model. So that 65% will actually come out lower once tensor cores are enabled.
@ngc92 This is because of the variable types we are using, right? Or do we need to turn on a flag explicitly?
My plan today is to make the shapes going into the gemm() template function proper shapes that have tensor-operation (MMA) support, so we benefit from that when we change the variable types.
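For what that could look like (a hedged sketch assuming an Ampere target and fp16 operands; the thread/warp counts are placeholders), the difference is essentially which MMA atom the TiledMMA is built from:

```cuda
#include <cute/tensor.hpp>
#include <cute/atom/mma_traits_sm80.hpp>
using namespace cute;

// CUDA-core path, roughly the kind of scalar work a plain gemm(sQ, sK, ...) falls
// back to: a scalar FMA atom tiled over 16x16x1 threads, one multiply-add per
// thread per element.
using TiledMmaFMA = decltype(make_tiled_mma(UniversalFMA<float, float, float>{},
                                            Layout<Shape<_16, _16, _1>>{}));

// Tensor-core path: an Ampere m16n8k16 mma.sync atom (fp16 operands, fp32
// accumulate), tiled as 4 atoms (4 warps) along M. This only pays off once the
// operand tiles are fp16/bf16 and are partitioned via
// tiled_mma.get_thread_slice(tid).partition_A/B/C rather than local_partition.
using TiledMmaTC  = decltype(make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                            Layout<Shape<_4, _1, _1>>{}));
```

The catch is that the tensor-core atom dictates its own per-thread fragment layouts, so the shared-memory tiles have to be partitioned through the TiledMMA instead of a hand-picked thread layout.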
It's a flag that needs to be set for cuBLAS.
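For reference, a sketch of the kind of flag being referred to (the handle name and helper are illustrative): putting the cuBLAS handle into TF32 tensor-op mode lets cublasSgemmStridedBatched use tensor cores even with fp32 inputs on Ampere and newer.

```cuda
#include <cublas_v2.h>

cublasHandle_t cublas_handle;   // illustrative name

void setup_cublas(int enable_tf32) {
    cublasCreate(&cublas_handle);
    // TF32 tensor-op math for fp32 GEMMs (requires compute capability >= 8.0);
    // CUBLAS_DEFAULT_MATH keeps cuBLAS on plain CUDA cores.
    cublasMath_t math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH
                                         : CUBLAS_DEFAULT_MATH;
    cublasSetMathMode(cublas_handle, math_mode);
}
```

With that set, the cuBLAS baseline gets considerably faster, which is why the 65% figure above is expected to drop.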