STT: Add sliding window and attention sink cache
1. Sliding‑window (token‑wise) trimming
Keep only the most‑recent N tokens in each layer’s key/value tensors.
MAX_CACHE_TOKENS = 256 # ≃ 12 s of speech for tiny/small
def trim_past_kv(past_kv, keep=MAX_CACHE_TOKENS):
"""past_kv is an EncoderDecoderCache.
Returns a view that keeps only the last keep positions."""
for layer in range(len(past_kv.self_attention_cache.key_cache)):
k = past_kv.self_attention_cache.key_cache[layer]
v = past_kv.self_attention_cache.value_cache[layer]
# k, v: (batch, heads, seq_len, head_dim)
past_kv.self_attention_cache.key_cache[layer] = k[:, :, -keep:, :]
past_kv.self_attention_cache.value_cache[layer] = v[:, :, -keep:, :]
# Tell the cache how many positions we just dropped
past_kv.offset -= max(0, past_kv.seq_len - keep)
return past_kv
Insert this right after you append each new token:
past_kv = trim_past_kv(past_kv)
Memory saved –
2 × L × H × D × keep × bytes_per_elem
Whisper‑small fp16 ⇒2 × 12 × 6 × 64 × 256 × 2 ≈ 3 MiBinstead of 6 MiB with a 512‑token cache.
Accuracy hit – Up to ~1 s before the window edge you’ll almost never notice; beyond that you might lose punctuation or long‑range re‑writes. With keep=256 most users report no audible quality drop in interactive dictation.
2. Sliding‑window with attention sink
Keep the first SINK_SIZE global “sink” tokens plus the newest WINDOW_SIZE tokens. These sink positions act as a stable anchor and restore most of the quality you’d lose with a pure sliding window — at essentially the same memory cost.
SINK_SIZE = 32 # first tokens kept forever
WINDOW_SIZE = 256 # recent tokens kept as usual
def trim_kv_with_sink(past_kv, sink=SINK_SIZE, window=WINDOW_SIZE):
"""Return a view that keeps <sink> head tokens and the last <window> ones."""
seq_len = past_kv.get_seq_length()
if seq_len <= sink + window:
return past_kv # nothing to trim yet
for layer in range(len(past_kv.self_attention_cache.key_cache)):
k = past_kv.self_attention_cache.key_cache[layer]
v = past_kv.self_attention_cache.value_cache[layer]
past_kv.self_attention_cache.key_cache[layer] = torch.cat(
[k[:, :, :sink, :], k[:, :, -window:, :]], dim=2)
past_kv.self_attention_cache.value_cache[layer] = torch.cat(
[v[:, :, :sink, :], v[:, :, -window:, :]], dim=2)
past_kv.offset += seq_len - (sink + window) # align cache_position
return past_kv
After generating each new token:
past_kv = trim_kv_with_sink(past_kv)
Memory saved –
(SINK_SIZE + WINDOW_SIZE) × 2 × L × H × D × bytes_per_elem.
With the values above on Whisper‑small fp16 this is ≈ 3.4 MiB.
3. Layer‑wise half‑caching
Keep all tokens, but only in every second decoder layer:
def drop_even_layers(past_kv):
for layer in range(0, len(past_kv.self_attention_cache.key_cache), 2):
past_kv.self_attention_cache.key_cache[layer] = None
past_kv.self_attention_cache.value_cache[layer] = None
return past_kv
When the model reaches a layer whose cache is None it just recomputes K/V on the fly. This roughly halves memory and keeps long‑range context, but adds ~8–12 % extra compute because the skipped layers do some re‑work.
4. Which one should you pick?
| Goal | Recommended trick |
|---|---|
| Lowest latency (GPU) | Sliding‑window, keep≈128–256 |
| Tiny VRAM (≤ 4 GB) | Window keep≈128 and drop even layers |
| CPU‑only laptop | Layer‑wise drop (compute‑bound > memory‑bound) |
| No quality loss allowed | Keep sink tokens and leave first 2–4 decoder layers fully cached |
Implementation notes & gotchas
-
Adjust
cache_position– after you slice tokens you must subtract the number you removed (past_kv.offsetabove) so the causal mask lines up. -
When you restart generation (new sentence, VAD gap, …) just set
past_kv = Nonerather than trying to “flush” it. -
All tricks work for
attn_implementation="sdpa"; they’re just plain tensor views.
TL;DR
Yes, trim away. A 256‑token sliding window is the cleanest 50 % memory win with negligible impact; adding sink tokens restores accuracy, and drop‑every‑other‑layer buys another ~25 %, paid for in a small speed tax. The core code is literally a single tensor[:, :, -N:] or torch.cat() line.