
To apply FlashAttention

Open dyanos opened this issue 1 year ago • 1 comment

To apply FlashAttention

  • https://github.com/HazyResearch/flash-attention
  • https://github.com/NVIDIA/cutlass

dyanos · Jun 19 '23 11:06

To install

pip install flash-attn

To apply

import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with the CUDA device you want to use
# (FlashAttention requires a GPU and fp16/bf16 inputs)
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128, # total channels (= num_heads * head_dim)
    num_heads=8, # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128), # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16
)

# FlashMHA returns (output, attention_weights); keep only the output
output = flash_mha(x)[0]
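Since the goal is to apply this inside a model rather than call it standalone, here is a minimal sketch (not oslo's actual code, just an assumption of how the integration could look) of FlashMHA sitting in a hypothetical pre-norm transformer block where torch.nn.MultiheadAttention would normally be used. The FlashTransformerBlock name and layer sizes are made up for illustration; only the FlashMHA interface shown above is relied on.

import torch
import torch.nn as nn
from flash_attn.flash_attention import FlashMHA

class FlashTransformerBlock(nn.Module):
    # Hypothetical pre-norm block: FlashMHA in place of nn.MultiheadAttention
    def __init__(self, embed_dim=128, num_heads=8, device="cuda:0", dtype=torch.float16):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.attn = FlashMHA(
            embed_dim=embed_dim,
            num_heads=num_heads,
            device=device,
            dtype=dtype,
        )
        self.norm2 = nn.LayerNorm(embed_dim, device=device, dtype=dtype)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim, device=device, dtype=dtype),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim, device=device, dtype=dtype),
        )

    def forward(self, x):
        # FlashMHA returns (output, attn_weights); only the output is used here
        x = x + self.attn(self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

block = FlashTransformerBlock()
y = block(torch.randn(2, 256, 128, device="cuda:0", dtype=torch.float16))
print(y.shape)  # torch.Size([2, 256, 128])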
The package also exposes the core attention computation on its own, without the input/output projection layers:

from flash_attn.flash_attention import FlashAttention

# Create the core attention nn.Module (projections are handled outside)
flash_attention = FlashAttention()
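A sketch of how it could be called, assuming the flash-attn 1.x interface (the packed-qkv layout and the causal keyword are assumptions, not shown above): the module takes a qkv tensor of shape (batch, seqlen, 3, num_heads, head_dim) in fp16 on the GPU and returns (output, attention_weights).

# Assumed 1.x interface: packed qkv of shape (batch, seqlen, 3, num_heads, head_dim);
# 8 heads of dim 16 matches embed_dim = 128 from the example above
qkv = torch.randn(64, 256, 3, 8, 16, device=device, dtype=torch.float16)
context, _ = flash_attention(qkv, causal=False)
print(context.shape)  # torch.Size([64, 256, 8, 16])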

dyanos · Jun 19 '23 12:06