gpt-fast
Question about large sequence length attention kernels
I really love this project and the accompanying blogpost, so thanks! I've reimplemented some of the inference techniques to speed up an implementation of Whisper that I am using. I had a few questions about the attention kernels, as they have been giving me some performance-related issues.
By adding print statements, I can see that during the attention calculation (not including prefilling) the shapes are essentially:
k - (bs, n_heads, max_seq_len, d_head)
v - (bs, n_heads, max_seq_len, d_head)
I understand that max_seq_len is there because of the static KV cache implementation. My understanding is that, thanks to the attention mask, F.scaled_dot_product_attention combined with torch.compile should be able to tell that it doesn't need to compute attention over the entire max_seq_len. In my case, however, I've found that the max_seq_len value has a big effect on inference speed, which suggests to me that full attention (over the entire max_seq_len context) is being performed on every iteration. This is vastly reduced when using the following context manager, as is done in generate.py:
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
If I exclude this, I am seeing a 3x reduction in tok/s. Does this make sense, or is it a sign that I've implemented something wrong? Even with this context manager, I see a significant (50%+) increase in tok/s when I reduce the context length from 4096 to 2048 or 1024.
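To make the question concrete, here is a minimal sketch of the kind of measurement I mean (the batch size, head count, head dim, and valid_len below are placeholder values, not my real configuration): a single-token decode step against padded caches of different max_seq_len, with a boolean mask that only exposes the first valid_len positions.

import torch
import torch.nn.functional as F

bs, n_heads, d_head, valid_len = 1, 8, 64, 256

def time_decode_step(max_seq_len: int, iters: int = 100) -> float:
    # One query token attending over a padded static cache of size max_seq_len.
    q = torch.randn(bs, n_heads, 1, d_head, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(bs, n_heads, max_seq_len, d_head, device="cuda", dtype=torch.bfloat16)
    v = torch.randn_like(k)
    # Boolean mask exposing only the first valid_len cache slots, as the causal mask does.
    mask = torch.zeros(1, 1, 1, max_seq_len, device="cuda", dtype=torch.bool)
    mask[..., :valid_len] = True
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per decode step

for max_seq_len in (1024, 2048, 4096):
    print(max_seq_len, time_decode_step(max_seq_len))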
Thanks in advance. If it helps, here is my CUDA-graph-friendly Whisper implementation using a static KV cache:
import math
from dataclasses import dataclass
from typing import Iterable, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

# @dataclass assumed: the fields below have no explicit __init__.
@dataclass
class ModelConfig:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
n_vocab: Optional[int] = None
def set_vocab_size(self, vocab_size: int):
self.n_vocab = vocab_size
class KVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.bfloat16,
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val, v_val: [B, H, S, D]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
return k_out, v_out
def sinusoids(
length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2)
)
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
class EncoderAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
def forward(
self,
xa: Tensor,
):
q = self.query(xa)
k = self.key(xa)
v = self.value(xa)
# Reshape for correct format
batch_size, source_seq_len, _ = k.shape
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=False,
)
wv = wv.transpose(1, 2).reshape(
batch_size,
target_seq_len,
self.n_head * self.d_head,
)
return self.out(wv)
class CrossAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_cache: KVCache | None = None
def get_kv(self, xa: torch.Tensor, xa_input_pos: Tensor):
assert self.kv_cache is not None, "No kv_cache"
k = self.key(xa[:, xa_input_pos])
v = self.value(xa[:, xa_input_pos])
# Reshape for correct format
batch_size, source_seq_len, _ = k.shape
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=xa_input_pos)
return k, v
def forward(
self,
x: Tensor,
xa: Tensor,
xa_input_pos: Tensor,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.get_kv(xa, xa_input_pos)
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=False,
)
wv = wv.transpose(1, 2).reshape(
batch_size,
target_seq_len,
self.n_head * self.d_head,
)
return self.out(wv)
class CausalSelfAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_cache: KVCache | None = None
def get_kv(self, x: torch.Tensor, input_pos: torch.Tensor):
# Self attn
k = self.key(x)
v = self.value(x)
# Reshape
batch_size, source_seq_len, _ = k.shape
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)
return k, v
def forward(
self,
x: Tensor,
mask: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.get_kv(x, input_pos=input_pos)
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=mask,
)
# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2).reshape(
batch_size, target_seq_len, self.n_head * self.d_head
)
return self.out(wv)
class EncoderAttentionBlock(nn.Module):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False
):
super().__init__()
self.attn = EncoderAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
)
self.mlp_ln = nn.LayerNorm(n_state)
def forward(
self,
xa: Tensor,
):
xa = xa + self.attn(
self.attn_ln(xa),
)
xa = xa + self.mlp(self.mlp_ln(xa))
return xa
class DecoderAttentionBlock(nn.Module):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False
):
super().__init__()
self.attn = CausalSelfAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = (
CrossAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
)
self.mlp_ln = nn.LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Tensor,
mask: Optional[Tensor] = None,
x_input_pos: Optional[Tensor] = None,
xa_input_pos: Optional[Tensor] = None,
):
x = x + self.attn(
self.attn_ln(x),
mask=mask,
input_pos=x_input_pos,
)
x = x + self.cross_attn(self.cross_attn_ln(x), xa, xa_input_pos)
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(
n_state, n_state, kernel_size=3, stride=2, padding=1
)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[EncoderAttentionBlock] = nn.ModuleList(
[EncoderAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = nn.LayerNorm(n_state)
def forward(self, xa: Tensor):
xa = F.gelu(self.conv1(xa))
xa = F.gelu(self.conv2(xa))
xa = xa.permute(0, 2, 1)
assert (
xa.shape[1:] == self.positional_embedding.shape
), f"incorrect audio shape: {xa.shape[1:]} != {self.positional_embedding.shape}"
xa = (xa + self.positional_embedding).to(xa.dtype)
for block in self.blocks:
xa = block(xa)
xa = self.ln_post(xa)
return xa
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[DecoderAttentionBlock] = nn.ModuleList(
[
DecoderAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = nn.LayerNorm(n_state)
self.register_buffer("causal_mask", None, persistent=False)
def forward(
self,
x: Tensor,
xa: Tensor,
x_input_pos: Tensor,
xa_input_pos: Tensor,
):
mask = self.causal_mask[None, None, x_input_pos]
x = self.token_embedding(x) + self.positional_embedding[x_input_pos]
for block in self.blocks:
x = block(
x=x,
xa=xa,
mask=mask,
x_input_pos=x_input_pos,
xa_input_pos=xa_input_pos,
)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
def setup_cache(
self,
batch_size,
max_seq_len=4096,
max_audio_len=1500,
):
self.causal_mask = torch.tril(
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
)
# Init cache
for b in self.blocks:
b.attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_seq_len,
n_heads=8,
head_dim=64,
).cuda()
b.cross_attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_audio_len,
n_heads=8,
head_dim=64,
).cuda()
class AmtEncoderDecoder(nn.Module):
def __init__(self, dims: ModelConfig):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
_buff = self.encoder(mel)
return self.decoder(tokens, _buff)
@property
def device(self):
return next(self.parameters()).device
This is a good question! I think there are two components to this:
- The default FlashAttention kernel is not very performant for decoding. See https://pytorch.org/blog/flash-decoding/ for more detail.
- The attention kernel we generate does not early-exit based on the mask, so its cost scales with the full max_seq_len even when most positions are masked out.
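To illustrate the second point, here is a minimal sketch (the shapes and valid_len are placeholders chosen to match the decoder above): the masked call over the padded cache produces the same output as attention over just the filled prefix, but the kernel still pays for the full max_seq_len, which is why reducing the context length shows up so directly in tok/s.

import torch
import torch.nn.functional as F

# Placeholder shapes: bs=1, 8 heads, d_head=64, max_seq_len=4096,
# with only the first valid_len cache slots actually filled.
q = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)
valid_len = 256
mask = torch.zeros(1, 1, 1, 4096, device="cuda", dtype=torch.bool)
mask[..., :valid_len] = True

# Masked attention over the padded cache vs. attention over just the filled prefix.
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_sliced = F.scaled_dot_product_attention(q, k[..., :valid_len, :], v[..., :valid_len, :])
print((out_masked - out_sliced).abs().max())  # ~0: same result, far less work in the sliced call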