
feat(attention_forward.cu): Gentle introduction to CuTe (Cutlass)

Open FeSens opened this issue 10 months ago • 9 comments

This is a very, very gentle introduction to Flash Attention 2 with CuTe (Cutlass v3). It's gentle because it's not finished.

What I've got so far:

  • Work partitioned across Query block, Batch, and Head (as in Flash Attention 2, to the best of my knowledge);
  • Efficient copying of tiles of Q and K using CuTe;
  • Using CuTe primitives for the matrix multiply (gemm) and the scaling (axpby, i.e. y = αx + βy); a rough sketch of how these pieces fit together is below.
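
To give a flavor of the pieces above, here is a minimal, untested sketch of a Q·Kᵀ tile kernel built from CuTe's copy, gemm, and axpby. The kernel name, tile sizes, and the "thread 0 does everything" approach are placeholders for illustration only (the real PR also partitions over batch and head, which is omitted here):

```cuda
#include <cute/tensor.hpp>

using namespace cute;

// One (BM x BN) tile of S = scale * Q * K^T, computed by a single thread
// block. For clarity, thread 0 does all the work here, which is correct but
// very slow; thread-level partitioning comes later in the thread.
// Assumes head_dim == BK (e.g. 64), row-major Q/K/S, and no batch/head dims.
template <int BM, int BN, int BK>
__global__ void qk_tile_kernel(const float* Q, const float* K, float* S,
                               int seq_len, int head_dim, float scale) {
    // Global-memory views: Q and K are (seq_len, head_dim), S is (seq_len, seq_len).
    Tensor mQ = make_tensor(make_gmem_ptr(Q),
                            make_shape(seq_len, head_dim), make_stride(head_dim, Int<1>{}));
    Tensor mK = make_tensor(make_gmem_ptr(K),
                            make_shape(seq_len, head_dim), make_stride(head_dim, Int<1>{}));
    Tensor mS = make_tensor(make_gmem_ptr(S),
                            make_shape(seq_len, seq_len), make_stride(seq_len, Int<1>{}));

    // Pick the tiles this block is responsible for.
    Tensor gQ = local_tile(mQ, make_tile(Int<BM>{}, Int<BK>{}),
                           make_coord(blockIdx.x, 0));            // (BM, BK)
    Tensor gK = local_tile(mK, make_tile(Int<BN>{}, Int<BK>{}),
                           make_coord(blockIdx.y, 0));            // (BN, BK)
    Tensor gS = local_tile(mS, make_tile(Int<BM>{}, Int<BN>{}),
                           make_coord(blockIdx.x, blockIdx.y));   // (BM, BN)

    // Shared-memory tiles with *static* layouts (see the note about static
    // shapes for gemm() further down the thread).
    __shared__ float smemQ[BM * BK];
    __shared__ float smemK[BN * BK];
    Tensor sQ = make_tensor(make_smem_ptr(smemQ),
                            make_layout(make_shape(Int<BM>{}, Int<BK>{})));
    Tensor sK = make_tensor(make_smem_ptr(smemK),
                            make_layout(make_shape(Int<BN>{}, Int<BK>{})));

    if (threadIdx.x == 0) {
        copy(gQ, sQ);   // global -> shared, element by element
        copy(gK, sK);

        // Accumulator for this tile of S.
        Tensor rS = make_tensor<float>(make_shape(Int<BM>{}, Int<BN>{}));
        clear(rS);

        // rS += sQ * sK^T. cute::gemm(C, A, B) takes A as (M, K) and B as
        // (N, K), so K does not need an explicit transpose.
        gemm(rS, sQ, sK);

        // gS = scale * rS, i.e. the 1/sqrt(head_dim) softmax scaling.
        axpby(scale, rS, 0.0f, gS);
    }
}
```

A launch might look like `qk_tile_kernel<64, 64, 64><<<dim3(seq_len / 64, seq_len / 64), 128>>>(...)`, but again the tile sizes and grid shape are illustrative, not what the PR uses.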

I'm converting this to a full PR because it may help people who want to start working with CuTe, and it may be less scary than jumping headfirst into a full flash attention implementation.

I will be hanging out on the CudaMode discord if anyone wants to pair or has a better understanding of CuTe and wants to help =).

FeSens avatar Apr 23 '24 20:04 FeSens

Ok, I've figured out that I need to make the layouts of sQ and sK static for now if I want to use gemm(sQ, sK, ...).
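
For anyone following along, the difference is roughly this (untested sketch; the 64x64 tile size is just a placeholder):

```cuda
#include <cute/tensor.hpp>
using namespace cute;

__global__ void layout_demo(int m, int k) {
    // Dynamic layout: the extents are runtime values, so the compiler
    // can't see the tile shape at the gemm() call site.
    auto dyn_layout = make_layout(make_shape(m, k));

    // Static layout: Int<64>{} bakes the extents into the type, which is
    // what the gemm(sQ, sK, ...) dispatch wants for the shared-memory tiles.
    auto sta_layout = make_layout(make_shape(Int<64>{}, Int<64>{}));

    __shared__ float smemQ[64 * 64];
    Tensor sQ = make_tensor(make_smem_ptr(smemQ), sta_layout);

    (void)dyn_layout; (void)sQ;  // demo only: silence unused-variable warnings
}
```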

FeSens avatar Apr 23 '24 22:04 FeSens

Is it compiling now?

ericauld avatar Apr 24 '24 01:04 ericauld

It is! Not at the performance I'd like yet, but it definitely compiles. Thanks for the Cutlass class @ericauld! Any tips on how to speed up this kernel?

FeSens avatar Apr 24 '24 07:04 FeSens

@FeSens can you post what kind of perf you're seeing for this?

karpathy avatar Apr 24 '24 18:04 karpathy

It's still far from cuBLAS. I'm working on getting the proper thread partitioning in place before moving on to the other parts needed for flash attention.

attention_query_key_kernel2 | Best time 54.376431 ms
cublasSgemmStridedBatched   | Best time 4.497695 ms

Once this part is at ~90% of the speed of cuBLAS, the real improvements should come from implementing the missing parts of flash attention.
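
For the thread-partitioning part, the relevant CuTe primitive is local_partition. A rough, untested sketch of what replacing the single-thread gemm with a per-thread partition could look like (the 16x8 thread layout, function name, and tile shapes are placeholders; the shared-memory copies would also need to be partitioned and followed by __syncthreads() before this runs):

```cuda
#include <cute/tensor.hpp>
using namespace cute;

// Drop-in replacement for the "thread 0 does everything" block in the earlier
// sketch: a 16x8 block of threads splits the (BM, BN) output tile so that
// each thread accumulates its own small sub-tile. Assumes blockDim.x == 128.
template <class SQTensor, class SKTensor, class GSTensor>
__device__ void partitioned_qk_gemm(SQTensor const& sQ,   // (BM, BK) in smem
                                    SKTensor const& sK,   // (BN, BK) in smem
                                    GSTensor      & gS,   // (BM, BN) in gmem
                                    float scale) {
    using X = Underscore;                                    // "skip this mode"
    auto tC = make_layout(make_shape(Int<16>{}, Int<8>{}));  // 16x8 = 128 threads

    // Each thread owns BM/16 rows of Q, BN/8 rows of K, and the matching
    // sub-tile of the output.
    Tensor tCsQ = local_partition(sQ, tC, threadIdx.x, Step<_1, X>{});  // (BM/16, BK)
    Tensor tCsK = local_partition(sK, tC, threadIdx.x, Step< X, _1>{}); // (BN/8,  BK)
    Tensor tCgS = local_partition(gS, tC, threadIdx.x, Step<_1, _1>{}); // (BM/16, BN/8)

    // Per-thread accumulator, then the same gemm/axpby as before, now on
    // much smaller per-thread tiles.
    Tensor tCrS = make_tensor_like(tCgS);
    clear(tCrS);
    gemm(tCrS, tCsQ, tCsK);
    axpby(scale, tCrS, 0.0f, tCgS);
}
```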

FeSens avatar Apr 24 '24 19:04 FeSens

This is now at 65% of the speed of cuBLAS's cublasSgemmStridedBatched.

FeSens avatar Apr 25 '24 07:04 FeSens

I believe you're running into the same trap I did. We currently don't enable tensor cores in this dev file, so cublasSgemmStridedBatched is much slower here than it will be in the real model. That 65% is therefore an overestimate of the relative performance.

ngc92 avatar Apr 25 '24 08:04 ngc92

@ngc92 Is this because of the variable types we are using, or do we need to turn on a flag explicitly? My plan today is to give the shapes going into the gemm() template function a proper layout that supports tensor-core operations, so we benefit from them once we change the variable types.

FeSens avatar Apr 25 '24 17:04 FeSens

It's a flag that needs to be set for cuBLAS.
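
For reference, one common way to do this is via cublasSetMathMode (sketch; assumes the handle is created elsewhere with cublasCreate, and the exact mode used in the dev file may differ):

```cuda
#include <cstdio>
#include <cublas_v2.h>

// Opt the cuBLAS handle into TF32 tensor-core math for FP32 GEMMs.
// Without something like this, cublasSgemmStridedBatched stays on the plain
// FP32 pipeline, which is the baseline the 65% figure was measured against.
void enable_tf32(cublasHandle_t handle) {
    cublasStatus_t st = cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);
    if (st != CUBLAS_STATUS_SUCCESS) {
        printf("cublasSetMathMode failed: %d\n", (int)st);
    }
}
```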

ngc92 avatar Apr 25 '24 17:04 ngc92