gpt-fast
Incorrect Results with FlexDecoding
@BoyuanFeng
Summary
The KVCache.update() method returns the entire cache buffer, including uninitialized (zero) positions (the pattern is sketched after this list), which causes significant numerical errors when using flex_attention. While this doesn't visibly affect discrete token generation (argmax hides the difference), it:
- Produces incorrect attention values (101% relative error in the reproduction below)
- Wastes computation on invalid cache positions
- Would cause severe issues for generation with real models, especially over longer contexts
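For context, the update pattern in question looks roughly like the sketch below (paraphrased from memory of gpt-fast's model.py, so details may differ): new keys/values are scattered into preallocated zero buffers of length max_seq_length, and the full buffers are returned to the attention call.
import torch
import torch.nn as nn

class KVCache(nn.Module):
    # Paraphrased sketch of the cache pattern described above, not the exact source.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        # Buffers are preallocated with zeros; positions never written stay zero.
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # Scatter the new keys/values into their slots...
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        # ...but hand back the *entire* max_seq_length buffers, including the
        # uninitialized (zero) tail that attention should never read.
        return self.k_cache, self.v_cache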
Reproduction
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
batch_size, num_heads, head_dim = 1, 32, 128
max_seq_length = 2048
current_position = 100
# Create query and KV cache
q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch_size, num_heads, max_seq_length, head_dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, num_heads, max_seq_length, head_dim, device=device, dtype=dtype)
# Fill only valid positions (0-99)
k_cache[:, :, :current_position] = torch.randn(batch_size, num_heads, current_position, head_dim, device=device, dtype=dtype)
v_cache[:, :, :current_position] = torch.randn(batch_size, num_heads, current_position, head_dim, device=device, dtype=dtype)
# Test 1: Current GPT-Fast approach (full cache)
def offset_causal_mask(b, h, q, kv):
    return (q + current_position - 1) >= kv
mask_full = create_block_mask(offset_causal_mask, B=batch_size, H=None, Q_LEN=1, KV_LEN=max_seq_length, device=device)
mask_full.seq_lengths = (1, max_seq_length) # As done in generate.py
output_full = flex_attention(q, k_cache, v_cache, block_mask=mask_full)
# Test 2: Correct approach (sliced cache)
k_sliced = k_cache[:, :, :current_position]
v_sliced = v_cache[:, :, :current_position]
mask_sliced = create_block_mask(causal_mask, B=batch_size, H=None, Q_LEN=1, KV_LEN=current_position, device=device)
mask_sliced.seq_lengths = (1, current_position)
output_sliced = flex_attention(q, k_sliced, v_sliced, block_mask=mask_sliced)
# Compare results
error = (output_full - output_sliced).abs()
print(f"Mean error: {error.mean().item():.6f}")
print(f"Relative error: {(error.mean() / output_sliced.abs().mean() * 100).item():.1f}%")
print(f"Full cache std: {output_full.std().item():.6f}")
print(f"Sliced cache std: {output_sliced.std().item():.6f}")
Results
Mean error: 0.816406
Relative error: 101.0%
Full cache std: 0.152802
Sliced cache std: 1.016770
The full-cache approach produces completely different results, with a 101% relative error!
While slicing the cache fixes the numerical issue, the K/V shapes then change at every decoding step, which is much slower and probably breaks the flash-decoding kernel's assumptions; a possible shape-bucketing mitigation is sketched below.
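If slicing is kept, one way to limit the dynamic-shape cost might be to bucket the sliced length to a fixed multiple so shapes only change every PAD tokens instead of every step, and mask out the not-yet-written slots. This is a hypothetical sketch, not how gpt-fast currently works; PAD, padded_mask, and the other names are invented for illustration, and it assumes the repro script above has already run.
# Hypothetical mitigation (not from gpt-fast): slice the cache to a padded
# length that is a multiple of a fixed bucket size, so K/V shapes only change
# every PAD tokens, and mask out the slots that have not been written yet.
PAD = 128  # assumed bucket size, not a gpt-fast value
padded_len = ((current_position + PAD - 1) // PAD) * PAD

def padded_mask(b, h, q_idx, kv_idx):
    # Admit only cache slots that have actually been written.
    return kv_idx < current_position

k_padded = k_cache[:, :, :padded_len]
v_padded = v_cache[:, :, :padded_len]
mask_padded = create_block_mask(padded_mask, B=batch_size, H=None,
                                Q_LEN=1, KV_LEN=padded_len, device=device)
output_padded = flex_attention(q, k_padded, v_padded, block_mask=mask_padded)
In a real decode loop current_position changes every step, so the block mask would need to be rebuilt per bucket (or the position passed as a tensor); the sketch only illustrates the shape-bucketing idea.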