
STT: Add sliding window and attention sink cache

Open · Blaizzy opened this issue 9 months ago · 0 comments

1. Sliding‑window (token‑wise) trimming

Keep only the most‑recent N tokens in each layer’s key/value tensors.

MAX_CACHE_TOKENS = 256          #  ≃ 12 s of speech for tiny/small

def trim_past_kv(past_kv, keep=MAX_CACHE_TOKENS):
    """past_kv is an EncoderDecoderCache. Keeps only the last `keep` positions
    of the decoder self-attention cache (the cross-attention cache is untouched)."""
    cache = past_kv.self_attention_cache
    seq_len = past_kv.get_seq_length()  # measure *before* slicing
    dropped = max(0, seq_len - keep)
    if dropped == 0:
        return past_kv  # nothing to trim yet

    for layer in range(len(cache.key_cache)):
        # k, v: (batch, heads, seq_len, head_dim)
        cache.key_cache[layer] = cache.key_cache[layer][:, :, -keep:, :]
        cache.value_cache[layer] = cache.value_cache[layer][:, :, -keep:, :]

    # Remember how many positions have been dropped so far
    # (a custom attribute, not part of the transformers cache API)
    past_kv.offset = getattr(past_kv, "offset", 0) + dropped
    return past_kv

Insert this right after you append each new token:

past_kv = trim_past_kv(past_kv)
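The effect of the trim is easy to see without a Whisper model. The minimal sketch below uses NumPy arrays as stand-ins for the per-layer key/value tensors. The list-of-(k, v) layout and the helper names `append_token` and `trim_window` are hypothetical and chosen only for illustration. They are not the transformers or mlx-audio cache API. Each new token's K/V is filled with its absolute position, which makes it visible which positions survive.

```python
import numpy as np

def append_token(cache, new_k, new_v):
    """Append one decoded position to every layer (hypothetical toy cache).

    cache: list of (k, v) pairs, each shaped (batch, heads, seq_len, head_dim).
    new_k, new_v: per-layer arrays shaped (batch, heads, 1, head_dim).
    """
    return [
        (np.concatenate([k, nk], axis=2), np.concatenate([v, nv], axis=2))
        for (k, v), nk, nv in zip(cache, new_k, new_v)
    ]

def trim_window(cache, keep):
    """Keep only the last `keep` positions per layer; return (cache, dropped).

    Assumes keep > 0 (a slice of -0: would keep everything).
    """
    seq_len = cache[0][0].shape[2]  # measure before slicing
    dropped = max(0, seq_len - keep)
    if dropped == 0:
        return cache, 0
    return [(k[:, :, -keep:, :], v[:, :, -keep:, :]) for k, v in cache], dropped

# Toy decoder: 2 layers, batch 1, 2 heads, head_dim 4, window of 5 tokens.
layers, heads, head_dim, keep = 2, 2, 4, 5
cache = [(np.zeros((1, heads, 0, head_dim)), np.zeros((1, heads, 0, head_dim)))
         for _ in range(layers)]
offset = 0
for step in range(8):
    kv = [np.full((1, heads, 1, head_dim), float(step)) for _ in range(layers)]
    cache = append_token(cache, kv, kv)
    cache, dropped = trim_window(cache, keep)
    offset += dropped

print(cache[0][0].shape)        # (1, 2, 5, 4)
print(cache[0][0][0, 0, :, 0])  # [3. 4. 5. 6. 7.]
print(offset)                   # 3
```

Eight tokens were decoded, and only the last five (positions 3 to 7) remain. The three dropped positions are what the offset has to account for.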

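Memory used: 2 × L × H × D × keep × bytes_per_elem (L decoder layers, H heads, D head dim; the factor 2 covers K and V).
Whisper‑small fp16 (L = 12, H = 12, D = 64) ⇒ 2 × 12 × 12 × 64 × 256 × 2 B = 9 MiB per sequence, instead of 18 MiB with a 512‑token cache.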
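The formula can be checked with a few lines of arithmetic. This sketch assumes batch size 1 and counts only the decoder self-attention cache. The cross-attention cache over the encoder output is a separate, fixed cost. The function name `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(layers, heads, head_dim, tokens, bytes_per_elem=2, batch=1):
    """Self-attention KV cache size: 2 (K and V) x L x H x D x tokens x bytes."""
    return 2 * layers * heads * head_dim * tokens * bytes_per_elem * batch

MIB = 1024 ** 2
small = dict(layers=12, heads=12, head_dim=64)  # Whisper-small decoder, fp16

print(kv_cache_bytes(**small, tokens=256) / MIB)  # 9.0
print(kv_cache_bytes(**small, tokens=512) / MIB)  # 18.0
```

The cache size is linear in the number of kept tokens, so halving `keep` halves the memory. That is where the 50 % figure comes from.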

Accuracy hit: tokens inside the window are unaffected. What you can lose is punctuation or long‑range rewrites that depend on context older than the window. With keep=256 the quality drop in interactive dictation is usually hard to notice.


2. Sliding‑window with attention sink

Keep the first SINK_SIZE tokens as global "sink" tokens, plus the newest WINDOW_SIZE tokens. The sink positions act as a stable attention anchor and restore most of the quality you'd lose with a pure sliding window, at the cost of only SINK_SIZE extra cached positions.

SINK_SIZE   = 32    # first tokens kept forever
WINDOW_SIZE = 256   # recent tokens kept as usual

import torch

def trim_kv_with_sink(past_kv, sink=SINK_SIZE, window=WINDOW_SIZE):
    """Keep the first `sink` tokens and the last `window` tokens of the
    decoder self-attention cache; drop everything in between."""
    seq_len = past_kv.get_seq_length()
    if seq_len <= sink + window:
        return past_kv  # nothing to trim yet

    cache = past_kv.self_attention_cache
    for layer in range(len(cache.key_cache)):
        k = cache.key_cache[layer]
        v = cache.value_cache[layer]
        # torch.cat copies the kept slices into new tensors (not a view)
        cache.key_cache[layer] = torch.cat(
            [k[:, :, :sink, :], k[:, :, -window:, :]], dim=2)
        cache.value_cache[layer] = torch.cat(
            [v[:, :, :sink, :], v[:, :, -window:, :]], dim=2)

    # Custom attribute (not transformers API): total positions dropped so far
    past_kv.offset = getattr(past_kv, "offset", 0) + seq_len - (sink + window)
    return past_kv

After generating each new token:

past_kv = trim_kv_with_sink(past_kv)
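The same kind of NumPy toy can illustrate the sink variant. After many decoding steps the cache should hold the first `sink` positions plus the last `window` ones. The list-of-(k, v) layout and the helper name `trim_with_sink` are assumptions for illustration. `np.concatenate` stands in for `torch.cat`.

```python
import numpy as np

def trim_with_sink(cache, sink, window):
    """Keep the first `sink` and last `window` positions; return (cache, dropped).

    cache: list of (k, v) pairs shaped (batch, heads, seq_len, head_dim).
    Assumes window > 0 (a slice of -0: would keep everything).
    """
    seq_len = cache[0][0].shape[2]
    if seq_len <= sink + window:
        return cache, 0  # nothing to trim yet
    trimmed = [
        (np.concatenate([k[:, :, :sink, :], k[:, :, -window:, :]], axis=2),
         np.concatenate([v[:, :, :sink, :], v[:, :, -window:, :]], axis=2))
        for k, v in cache
    ]
    return trimmed, seq_len - (sink + window)

# One layer, batch 1, 1 head, head_dim 2; sink of 2 tokens, window of 3.
sink, window = 2, 3
cache = [(np.zeros((1, 1, 0, 2)), np.zeros((1, 1, 0, 2)))]
offset = 0
for step in range(10):
    new = np.full((1, 1, 1, 2), float(step))  # tag each token with its position
    cache = [(np.concatenate([k, new], axis=2), np.concatenate([v, new], axis=2))
             for k, v in cache]
    cache, dropped = trim_with_sink(cache, sink, window)
    offset += dropped

print(cache[0][0][0, 0, :, 0])  # [0. 1. 7. 8. 9.]
print(offset)                   # 5
```

Positions 0 and 1 (the sink) survive every trim. The window slides over the rest, so the cache size stays fixed at sink + window.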

Memory used: (SINK_SIZE + WINDOW_SIZE) × 2 × L × H × D × bytes_per_elem.
With the values above (288 tokens) on Whisper‑small fp16 this is ≈ 10.1 MiB, only about 1.1 MiB more than the plain 256‑token window.


3. Layer‑wise half‑caching

Keep all tokens, but only in every second decoder layer:

def drop_even_layers(past_kv):
    """Free the self-attention KV cache of every even-indexed decoder layer (0, 2, 4, ...)."""
    for layer in range(0, len(past_kv.self_attention_cache.key_cache), 2):
        past_kv.self_attention_cache.key_cache[layer]   = None
        past_kv.self_attention_cache.value_cache[layer] = None
    return past_kv

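When the decoder reaches a layer whose cache is None, it has to recompute that layer's K/V for the whole prefix on the fly. This needs a custom decoder loop, because the stock transformers cache classes are not written to handle None entries. Since K/V are projections of the layer's input hidden states, the loop also needs those inputs for every past token. The trick roughly halves KV memory and keeps long‑range context, but adds extra compute (estimated at ~8–12 %) because the uncached layers redo work on every step.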
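Which layers keep a cache is just an index rule. The sketch below uses a hypothetical helper that is not part of transformers. It combines the drop-even-layers pattern with the option from the selection table of keeping the first few decoder layers fully cached. It only chooses layers and does not implement the K/V recomputation itself.

```python
def cached_layers(num_layers, keep_first=0):
    """Indices of decoder layers that keep a self-attention KV cache.

    Mirrors drop_even_layers: even-indexed layers are uncached, except that the
    first `keep_first` layers are always cached.
    """
    return [i for i in range(num_layers) if i < keep_first or i % 2 == 1]

print(cached_layers(12))     # [1, 3, 5, 7, 9, 11]
print(cached_layers(12, 4))  # [0, 1, 2, 3, 5, 7, 9, 11]
```

With keep_first=4 on a 12-layer decoder such as Whisper‑small, 8 of 12 layers stay cached. The saving therefore falls from 50 % to about 33 %.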


4. Which one should you pick?

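| Goal                    | Recommended trick                                                     |
|-------------------------|-----------------------------------------------------------------------|
| Lowest latency (GPU)    | Sliding window, keep ≈ 128–256                                        |
| Tiny VRAM (≤ 4 GB)      | Sliding window with keep ≈ 128, plus drop even layers                 |
| CPU‑only laptop         | Layer‑wise drop (compute‑bound > memory‑bound)                        |
| No quality loss allowed | Keep sink tokens and leave the first 2–4 decoder layers fully cached  |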

Implementation notes & gotchas

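  1. Adjust cache_position: after slicing, the cache is shorter than the number of tokens decoded so far. Keep track of how many positions you removed (past_kv.offset above) and subtract it from the absolute position wherever a cache index or causal‑mask length is computed, so the mask lines up with the trimmed cache.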

  2. When you restart generation (new sentence, VAD gap, …) just set past_kv = None rather than trying to “flush” it.

  3. All tricks work with attn_implementation="sdpa". They only change the cached tensors (slices, or torch.cat copies), so the attention kernel itself is untouched.
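Note 1 is the easiest to get wrong, because two positions are in play after trimming. One is the absolute position of the next token, which a learned positional embedding (as in Whisper's decoder) should be indexed with. The other is its index inside the shortened cache, which the causal mask must match. The sketch below shows that bookkeeping, assuming `offset` counts the total number of positions dropped so far. `CacheBookkeeping` is a hypothetical name. With Whisper, the absolute position must also stay below the decoder's positional-embedding limit (max_target_positions, 448 in the released configs), and trimming alone does not lift that limit.

```python
from dataclasses import dataclass

@dataclass
class CacheBookkeeping:
    cache_len: int = 0  # positions currently held in the trimmed cache
    offset: int = 0     # total positions dropped so far

    def next_positions(self):
        """Return (absolute position of the next token, its index in the cache)."""
        absolute = self.cache_len + self.offset
        # The causal mask for the next step spans in_cache + 1 keys
        # (the cached positions plus the new token itself).
        return absolute, absolute - self.offset

    def append_and_trim(self, keep):
        """Account for one appended token, then a sliding-window trim to `keep`."""
        self.cache_len += 1
        dropped = max(0, self.cache_len - keep)
        self.cache_len -= dropped
        self.offset += dropped

state = CacheBookkeeping()
for _ in range(300):
    state.append_and_trim(keep=256)

absolute, in_cache = state.next_positions()
print(absolute, in_cache)  # 300 256
```

Restarting generation (note 2) means starting again from a fresh `CacheBookkeeping()` alongside `past_kv = None`.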


TL;DR

Yes, trim away. A 256‑token sliding window is the cleanest 50 % memory win with negligible impact. Adding sink tokens restores most of the lost accuracy, and dropping every other layer buys another ~25 % at the cost of a small speed tax. The core code is literally a single tensor[:, :, -N:] or torch.cat() line.

Blaizzy, May 09 '25 14:05