oslo
To apply FlashAttention
- https://github.com/HazyResearch/flash-attention
- https://github.com/NVIDIA/cutlass
To install
```bash
pip install flash-attn
```
To apply
```python
import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create the attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers.
flash_mha = FlashMHA(
    embed_dim=128,  # total channels (= num_heads * head_dim)
    num_heads=8,    # number of heads
    device=device,
    dtype=torch.float16,
)

# Run a forward pass with dummy data
x = torch.randn(
    (64, 256, 128),  # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16,
)

output = flash_mha(x)[0]
```
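Because FlashMHA bundles the QKV and output projections, it can stand in for a complete self-attention sub-layer. The sketch below is illustrative only: the `Block` class and its pre-norm layout are hypothetical and not part of oslo or flash-attn, but they show where FlashMHA would slot into an existing model.

```python
import torch
import torch.nn as nn
from flash_attn.flash_attention import FlashMHA

class Block(nn.Module):
    """Hypothetical pre-norm transformer block using FlashMHA for self-attention."""

    def __init__(self, embed_dim=128, num_heads=8, device="cuda:0"):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, device=device, dtype=torch.float16)
        # FlashMHA already contains the input and output linear layers,
        # so no extra projection modules are needed here.
        self.attn = FlashMHA(
            embed_dim=embed_dim,
            num_heads=num_heads,
            device=device,
            dtype=torch.float16,
        )

    def forward(self, x):
        # FlashMHA returns a tuple; keep only the attention output
        return x + self.attn(self.norm(x))[0]
```

If you only need the attention kernel itself, without the bundled projections, flash-attn also exposes the lower-level FlashAttention module: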
```python
from flash_attn.flash_attention import FlashAttention

# Create the core attention nn.Module (attention only, without the
# input/output projections that FlashMHA provides)
flash_attention = FlashAttention()
```
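FlashAttention operates on an already-projected, packed query/key/value tensor rather than on raw hidden states. The following is a minimal sketch assuming the flash-attn v1.x interface, where `qkv` has shape `(batch, seqlen, 3, num_heads, head_dim)` and the module returns an `(output, attention_weights)` tuple; the shapes match the FlashMHA example above (128 channels = 8 heads x 16 dims). Check your installed version, since the interface changed in flash-attn 2.x.

```python
import torch
from flash_attn.flash_attention import FlashAttention

device = "cuda:0"

# Bare attention module (softmax scale and dropout are optional arguments)
flash_attention = FlashAttention()

# Packed QKV tensor: (batch, seqlen, 3, num_heads, head_dim)
# batch=64, seqlen=256, num_heads=8, head_dim=16 are illustrative values
qkv = torch.randn(
    (64, 256, 3, 8, 16),
    device=device,
    dtype=torch.float16,
)

# Returns (output, attention_weights); the weights are None unless requested
out, _ = flash_attention(qkv, causal=False)
print(out.shape)  # torch.Size([64, 256, 8, 16])
```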