
Speedup `attention_forward_kernel2` by implementing Flash Attention 2 kernel

[Open] leloykun opened this issue 1 year ago • 2 comments

This speeds up the `attention_forward_kernel2` kernel by replacing its implementation with a minimal Flash Attention 2 kernel, as found in https://github.com/leloykun/flash-hyperbolic-attention-minimal/blob/main/flash_attention_2.cu

Benchmark results on an A100 (80GB)

Attention implementation V1: [benchmark screenshot]

Flash Attention 2 implementation: [benchmark screenshot]

leloykun avatar Apr 11 '24 06:04 leloykun

Very cool! @leloykun could it make sense to maintain both Flash Attention 1 and 2 separately? E.g. Flash Attention 2 as kernel4? I think having multiple versions is great / educational for further kernel development and as a reference.

karpathy avatar Apr 11 '24 15:04 karpathy

Another question: we do eventually want to implement the backward pass for all of these. Should we not leave the variable `l` intact, with these future plans in mind?

karpathy avatar Apr 11 '24 15:04 karpathy