Fixes StaticCache Crashes
What does this PR do?
Fixes #42454
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @zucchini-nlp @Rocketknight1 @mobicham
Benchmarking script:
import torch
import time
import numpy as np
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import StaticCache, EncoderDecoderCache
MODEL_ID = "openai/whisper-tiny"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_BATCH_SIZE = 64
TEST_BATCHES = [1, 8, 32, 64]
SEQ_LEN = 128
WARMUP = 10
REPEATS = 50
def load_model():
model = WhisperForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
attn_implementation="sdpa",
).to(DEVICE)
model.eval()
return model
def run_benchmark(model, batch_size, cache_cap, tag):
decoder = model.model.decoder
cache_len = SEQ_LEN + 10
enc_len = SEQ_LEN
self_cache = StaticCache(
config=decoder.config,
max_batch_size=cache_cap,
max_cache_len=cache_len,
device=DEVICE,
dtype=DTYPE,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=cache_cap,
max_cache_len=enc_len,
device=DEVICE,
dtype=DTYPE,
)
kv_cache = EncoderDecoderCache(self_cache, cross_cache)
input_ids = torch.randint(0, 1000, (batch_size, SEQ_LEN), device=DEVICE)
encoder_states = torch.randn(batch_size, enc_len, model.config.d_model, device=DEVICE, dtype=DTYPE)
cache_pos = torch.arange(SEQ_LEN, device=DEVICE)
for _ in range(WARMUP):
kv_cache.reset()
with torch.no_grad():
decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_states,
past_key_values=kv_cache,
cache_position=cache_pos,
use_cache=True,
)
# actual runs
start = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]
end = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]
torch.cuda.synchronize()
for i in range(REPEATS):
kv_cache.reset()
start[i].record()
with torch.no_grad():
decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_states,
past_key_values=kv_cache,
cache_position=cache_pos,
use_cache=True,
)
end[i].record()
torch.cuda.synchronize()
times = [s.elapsed_time(e) for s, e in zip(start, end)]
avg = float(np.mean(times))
print(f"[{tag}] batch={batch_size:2d} cache_cap={cache_cap:2d} latency={avg:.2f} ms")
return avg
model = load_model()
results = []
for bs in TEST_BATCHES:
base = run_benchmark(model, batch_size=bs, cache_cap=bs, tag="BASELINE")
sliced = run_benchmark(model, batch_size=bs, cache_cap=MAX_BATCH_SIZE, tag="SLICED")
diff = (sliced - base) / base * 100
results.append((bs, base, sliced, diff))
print("Summary:")
print(f"{'Batch':<8} | {'Baseline (ms)':<15} | {'Sliced (ms)':<15} | Diff (%)")
for bs, base, sliced, diff in results:
print(f"{bs:<8} | {base:<15.2f} | {sliced:<15.2f} | {diff:+.2f}%")
@zucchini-nlp This doesn't have the max_batch_size as you mentioned. If it's something I should add, please let me know.
I think we need to allow max_batch_size, which will take precedence, if available, when lazily initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use the cache across several generations with a max batch size. It would also require us to change a few places in generation, imo.
So basically I should add max_batch_size to the __init__ method of StaticCache, and then modify lazy_initialization in StaticLayer to use max_batch_size.
And also change line 1837 of src/transformers/generation/utils.py as in #37394, so the condition includes:
or cache_to_check.max_batch_size < batch_size
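Roughly, that check would then become something like this (just a sketch; the surrounding variable names in _get_cache may differ, so treat them as assumptions):
need_new_cache = (
    not hasattr(self, "_cache")
    or not isinstance(cache_to_check, cache_cls)
    # new clause: treat max_batch_size as a capacity rather than an exact match
    or cache_to_check.max_batch_size < batch_size
)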
Yep, and a small test as well
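One possible shape for such a test (only a sketch, assuming the max_batch_size plumbing this PR adds; not the exact test to ship):
import torch
from transformers import GPT2Config
from transformers.cache_utils import StaticCache

config = GPT2Config(n_layer=2, n_head=4, n_embd=64)
cache = StaticCache(config=config, max_batch_size=8, max_cache_len=16)

# run one update with a smaller running batch than the allocated capacity
key = torch.randn(2, 4, 4, 16)  # [batch, heads, seq, head_dim]
value = torch.randn_like(key)
k_out, v_out = cache.layers[0].update(key, value, {"cache_position": torch.arange(4)})

assert k_out.shape[0] == 2                  # returned view matches the running batch size
assert cache.layers[0].max_batch_size == 8  # backing tensors keep the full capacity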
Hi @zucchini-nlp, I'm still stuck on this. I've been testing with torch.compile and it works fine with GPT-2, but it does not work with Whisper small; I'm not sure what I'm missing, tbh.
If you have any pointers on what I should check or tweak, I’d really appreciate it.
Thanks a lot and sorry for the trouble
@i3hz I will also do some debugging next week
Thanks a lot! The main issue still lies within torch.compile, as the model works without it.
@i3hz I tried the slicing solution but it throws an attention error even without torch.compile:
RuntimeError: The size of tensor a (4) must match the size of tensor b (8) at non-singleton dimension 0
Something is strange, this works:
for bs in [8, 4, 2, 1]:
past_key_values = create_cache(max_batch_size)
...
but when the cache is allocated only once, it throws that error:
past_key_values = create_cache(max_batch_size)
for bs in [8, 4, 2, 1]:
...
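A minimal standalone illustration of the failure mode (not the model code): the cache was allocated for batch 8, so attention ends up comparing a batch-4 query against batch-8 keys/values.
import torch

q = torch.randn(4, 8, 3, 64)    # query: [batch=4, heads, q_len, head_dim]
k = torch.randn(8, 8, 16, 64)   # cached keys still carry the allocated batch=8
v = torch.randn(8, 8, 16, 64)
# raises a batch-dimension mismatch like the error quoted above
torch.nn.functional.scaled_dot_product_attention(q, k, v)
Full repro: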
import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
################################################
from transformers import cache_utils
from typing import Optional, Any
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.is_initialized:
self.lazy_initialization(key_states)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"
print(f"{batch_size}:{self.max_batch_size}")
k_out = self.keys[:batch_size]
v_out = self.values[:batch_size]
# Update the cache
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
cache_utils.StaticLayer.update = update
################################################
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
#torch._dynamo.config.capture_scalar_outputs = True
#run_encoder = torch.compile(run_encoder, mode='reduce-overhead', fullgraph=True)
max_batch_size = 8
past_key_values = create_cache(max_batch_size)
for bs in [8, 4, 2, 1]:
assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs.clone(), past_key_values_out, prefill=False)
So the issue is this part: https://github.com/huggingface/transformers/blob/7f5c20945a97ed960eb85d96b93c89f33772fd20/src/transformers/models/whisper/modeling_whisper.py#L329-L330
If you replace it with this, it works:
key_states = past_key_values.layers[self.layer_idx].keys[:bsz]
value_states = past_key_values.layers[self.layer_idx].values[:bsz]
However, the problem is that we can't do this for every modeling file separately. I guess the solution is to do something with self.keys and self.values. Like this, it works with torch.compile:
@zucchini-nlp what do you think?
class StaticLayer(CacheLayerMixin):
"""
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
"""
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int):
super().__init__()
self.max_cache_len = max_cache_len
def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.keys, self.values = self.keys_, self.values_
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_)
torch._dynamo.mark_static_address(self.values_)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"
self.keys = self.keys_[:batch_size]
self.values = self.values_[:batch_size]
# Update the cache
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
"""Return the maximum cache shape of the cache"""
return self.max_cache_len
@i3hz I tried the slicing solution but it throws an attention error even without torch.compile:
You are right, it's not working :c
I switched models to GPT-2 and it does work (as I mentioned before). I really don't know why that's happening; is it because GPT-2 does not use EncoderDecoderCache?
Or I think the problem is probably that in the testing script for Whisper we only ran the encoding logic, whereas now we're also running the decoding logic (which is an oversight on my part, sorry), so your suggestion about self.keys_ and self.values_ might be the correct fix.
My reproduction script, which uses GPT-2 instead of Whisper (and does run successfully), in case you need it:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
device = "cuda"
model_id = "openai-community/gpt2"
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=torch.float16,
attn_implementation="sdpa",
ignore_mismatched_sizes=True
).to(device)
model.eval()
def decode_step(model, input_ids, past_key_values, cache_position):
out = model(
input_ids=input_ids,
past_key_values=past_key_values,
cache_position=cache_position,
use_cache=True,
)
logits = out.logits
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
return next_token, out.past_key_values
compiled_decode = torch.compile(decode_step, mode="reduce-overhead", fullgraph=True)
max_batch_size = 8
max_seq_len = 64
dtype = torch.float16
past_key_values = StaticCache(
config=model.config,
max_batch_size=max_batch_size,
max_cache_len=max_seq_len,
device=device,
dtype=dtype
)
batch_sizes = [8, 4, 2, 1]
try:
for bs in batch_sizes:
print(f"Batch Size: {bs}")
past_key_values.reset()
seq_len = 3
input_ids = torch.randint(0, 1000, (bs, seq_len), device=device)
cache_position = torch.arange(seq_len, device=device)
with torch.no_grad():
out = model(input_ids, past_key_values=past_key_values, cache_position=cache_position)
cur_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
cache_position = torch.tensor([seq_len], device=device)
cur_token, _ = compiled_decode(model, cur_token, past_key_values, cache_position)
print("Success")
except Exception as e:
print(f"Failed on {bs} with error {e}")
@i3hz yeah, because the issue is that at some point it returns self.keys and self.values, not just for Whisper but also for other models. The self.keys_ / self.values_ trick works; I think we just need to update the reset() function so that it resets self.keys_ / self.values_ instead.
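Something like this (a sketch, using the self.keys_ / self.values_ names from the snippet above):
def reset(self):
    if self.is_initialized:
        # zero the full backing tensors, not the batch-sliced views
        self.keys_.zero_()
        self.values_.zero_()
        # restore the un-sliced views so the next prefill re-slices for its own batch size
        self.keys, self.values = self.keys_, self.values_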
I've implemented the self.keys_ and self.values_ functionality.
Along with that, I also had to override update, reset, and __len__ for StaticCache (to trigger updates for cross-attention).
In StaticLayer I've overridden the reset method as well (to correctly reset the cache).
And I've also added a max_batch_size parameter to StaticCache and StaticLayer.
So the testing script from earlier does work, but torch.compile still fails with a segmentation fault, which I'm working on. Is this the expected fix, @zucchini-nlp @mobicham?
(misclicked and accidentally closed the PR, my bad)
class StaticLayer(CacheLayerMixin):
"""
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
"""
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
if self.max_batch_size is None:
self.max_batch_size = key_states.shape[0]
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.keys = self.keys_
self.values = self.values_
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_)
torch._dynamo.mark_static_address(self.values_)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
# 3. Dynamic Slicing: Update the view to match current batch
self.keys = self.keys_[:batch_size]
self.values = self.values_[:batch_size]
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_.zero_()
self.values_.zero_()
self.keys = self.keys_
self.values = self.values_
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
"""Return the maximum cache shape of the cache"""
return self.max_cache_len
This is my StaticLayer class.
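For context, the cache-level side is mostly plumbing; a sketch of the idea (the real StaticCache and Cache constructors take more arguments than shown here, so treat the signatures as assumptions):
class StaticCache(Cache):
    def __init__(self, config, max_cache_len, max_batch_size=None, **kwargs):
        # forward the optional capacity down to every layer; lazy_initialization
        # falls back to key_states.shape[0] when it is None
        layers = [
            StaticLayer(max_cache_len=max_cache_len, max_batch_size=max_batch_size)
            for _ in range(config.num_hidden_layers)
        ]
        super().__init__(layers=layers, **kwargs)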
However, the problem is that we can't do this for every modeling file separately. I guess the solution is to do something with self.keys and self.values, like this, it works with torch.compile:
Ah right, encoder-decoder ones are a bit different. Naming the kv with max batch size differently sounds good to me. Probably a bit more informative name would be better, it's easy to lose track when reading the code
@i3hz can you push the code you have with all the updates. Also, in which cases you're getting a seg fault, in test files or in bench script? It is important to not compile a prefill stage in custom generation loop, or if we have to compile in advance then cache has to be early initialized. The lazy init function is known to fail when compiled
The code is a bit messy, but I'll change it later, sorry.
Probably a bit more informative name would be better, it's easy to lose track when reading the code
Yeah, probably self.keys_, self.values_ -> self.keys_full, self.values_full, or something like that.
But torch compile still fails with a segmentation fault which I'm working on.
@i3hz do you still have this issue? torch.compile works fine with the self.keys_ / self.values_ trick, at least with Whisper; are other models not working too?
@i3hz do you still have this issue? torch.compile works fine with the self.keys_, self.values_ trick, at least with Whisper, are other models not working too?
It failed on my end. Can you please look into the implementation and let me know if I messed something up?
Also, in which cases you're getting a seg fault, in test files or in bench script? It is important to not compile a prefill stage in custom generation loop, or if we have to compile in advance then cache has to be early initialized. The lazy init function is known to fail when compiled
Yeah, I forgot about this, sorry. I split the test into an eager prefill + compiled decode and it now passes successfully.
But it is failing the CI tests, which is something I'm working on.
I would really appreciate any pointers on why the CI tests are failing @zucchini-nlp
Welp, all tests do pass now. I've also refactored the code a little bit. Let me know if there are any changes to be made.
Thanks @i3hz , I will test compile with your latest changes on Monday
@i3hz your version gives me Segmentation fault (core dumped). You also need to assign self.keys and self.values as the truncated cache, not the full cache; I don't think it's possible to use self.keys / self.values directly.
This works and passes the compile test
BUT the output is incorrect:
class StaticLayer(CacheLayerMixin):
"""
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
max_batch_size(`int`, *optional*):
Maximum batch size that can be stored
"""
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
if self.max_batch_size is None:
self.max_batch_size = key_states.shape[0]
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.keys, self.values = self.keys_full, self.values_full
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys)
torch._dynamo.mark_static_address(self.values)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"
self.keys = self.keys_full[:batch_size]
self.values = self.values_full[:batch_size]
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_full.zero_()
self.values_full.zero_()
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
"""Return the maximum cache shape of the cache"""
return self.max_cache_len
Test
import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
max_batch_size = 8
past_key_values = create_cache(max_batch_size)
for _ in range(2):
for bs in [8, 4, 2, 1]:
past_key_values.reset()
assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
print(f"{bs} pass!")
run_encoder = torch.compile(run_encoder, mode='reduce-overhead', fullgraph=True)
print('----- compiled run ---- ')
Output
8 pass!
4 pass!
2 pass!
1 pass!
----- compiled run ----
8 pass!
4 pass!
2 pass!
1 pass!
@i3hz can you push the code you have with all the updates. Also, in which cases you're getting a seg fault, in test files or in bench script? It is important to not compile a prefill stage in custom generation loop, or if we have to compile in advance then cache has to be early initialized. The lazy init function is known to fail when compiled
The script you provided is trying to compile the prefill stage as well, which might be the reason for the segmentation fault, @mobicham.
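One way to keep the prefill eager in that test (illustrative only, reusing the names from the script above): compile a separate handle and route only the decode step through it.
decode_compiled = torch.compile(run_encoder, mode="reduce-overhead", fullgraph=True)

for bs in [8, 4, 2, 1]:
    past_key_values.reset()
    labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
    encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)
    # prefill stays eager, so lazy_initialization is never traced
    cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
    # only the single-token decode step goes through the compiled path
    cur_token, past_key_values_out = decode_compiled(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)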
BUT the output is incorrect
wdym
@i3hz it doesn't matter which stage is compiled actually, I was getting seg fault even without compile 🤔
BUT the output is incorrect
For Whisper only, trying to see what's going on.
class StaticLayer(CacheLayerMixin):
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
self._current_batch_size = None # Track current batch size
def lazy_initialization(self, key_states: torch.Tensor):
if self.max_batch_size is None:
self.max_batch_size = key_states.shape[0]
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_full)
torch._dynamo.mark_static_address(self.values_full)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.is_initialized:
self.lazy_initialization(key_states)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
self._current_batch_size = batch_size
# Slice to current batch size
self.keys = self.keys_full[:batch_size]
self.values = self.values_full[:batch_size]
# Use slice assignment instead of index_copy_ for better compatibility
# with variable batch sizes
if cache_position.numel() == self.max_cache_len:
# Full cache update (e.g., cross-attention initialization)
self.keys.copy_(key_states)
self.values.copy_(value_states)
else:
# Incremental update
try:
# Try index_copy_ first (faster when it works)
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except (NotImplementedError, RuntimeError):
# Fallback to direct indexing
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_full.zero_()
self.values_full.zero_()
self._current_batch_size = None
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
return self.max_cache_len
class StaticSlidingWindowLayer(StaticLayer):
"""
A static cache layer that stores the key and value states as static tensors of shape
`[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
sliding_window (`int`):
The size of the sliding window.
max_batch_size(`int`, *optional*):
Maximum batch size that can be stored
"""
is_sliding = True
def __init__(self, max_cache_len: int, sliding_window: int, max_batch_size: int | None = None):
effective_max_cache_len = min(sliding_window, max_cache_len)
super().__init__(max_cache_len=effective_max_cache_len, max_batch_size=max_batch_size)
self.cumulative_length = 0
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
self.keys = self.keys_full[:batch_size]
self.values = self.values_full[:batch_size]
cumulative_length = self.cumulative_length
is_full = cumulative_length >= self.max_cache_len
# Update it now that we saved the value above
self.cumulative_length += key_states.shape[-2]
if is_full:
# In general, we should use a much simpler `cat` here as well, independently of the states size. However,
# dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
if key_states.shape[-2] == 1:
# Roll all values to the left by 1 position
new_keys = self.keys.roll(-1, dims=-2)
new_values = self.values.roll(-1, dims=-2)
# Overwrite the last position with new states
# (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
index = torch.tensor([-1], dtype=int, device=self.device)
new_keys[:, :, index] = key_states
new_values[:, :, index] = value_states
# Copy back into the batch-sliced portion (NOT the full cache)
self.keys[:batch_size].copy_(new_keys)
self.values[:batch_size].copy_(new_values)
# Return the batch-sliced view
return self.keys, self.values
# Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
else:
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
# Not yet full, but becoming full on this update
elif cumulative_length + key_states.shape[2] > self.max_cache_len:
# Fast prefill path, no need to cat() in this case, as the cache is currently empty
if cumulative_length == 0:
full_key_states = key_states
full_value_states = value_states
else:
full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
else:
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
# Return the batch-sliced view
return self.keys, self.values
# We only cache the last `sliding_window` tokens
self.keys[:batch_size].copy_(full_key_states[:, :, -self.max_cache_len :, :])
self.values[:batch_size].copy_(full_value_states[:, :, -self.max_cache_len :, :])
# we should return the whole states instead of sliced cache here, as otherwise we lose some context
return full_key_states, full_value_states
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
query_length = cache_position.shape[0]
sliding_window = self.max_cache_len
is_full = self.cumulative_length >= self.max_cache_len
kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
# The cache is already full
if is_full:
kv_length = sliding_window + query_length - 1
# Not yet full, but becoming full on this update
elif self.cumulative_length + query_length > sliding_window:
kv_length = self.cumulative_length + query_length
# Here the Cache is still smaller than the local size, but we return the local size as it's static
else:
kv_length = sliding_window
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
return self.cumulative_length
def reset(self):
if self.is_initialized:
self.keys.zero_()
self.values.zero_()
self.cumulative_length = 0
@mobicham can you check if this one works?
@i3hz yes it works 👍
The incorrect output is not related to the static cache code update we are doing here, the output is incorrect with Whisper + static even without the code update 🤔 , so probably an unrelated bug.
The incorrect output is not related to the static cache code update we are doing here, the output is incorrect with Whisper + static even without the code update 🤔 , so probably an unrelated bug.
All good, the error was on my side
What should I add as a test?
Wait, this test
FAILED tests/models/qwen3_omni_moe/test_processing_qwen3_omni_moe.py::Qwen3OmniMoeProcessorTest::test_apply_chat_template_video_frame_sampling - RuntimeError: Failed to open input buffer: Invalid data found when processing input
failed on the CI, but it passes on my local machine?
(venv) ➜ transformers git:(static_cache) pytest tests/models/qwen3_omni_moe/test_processing_qwen3_omni_moe.py::Qwen3OmniMoeProcessorTest::test_apply_chat_template_video_frame_sampling
================================================= test session starts ==================================================
platform linux -- Python 3.12.3, pytest-8.4.2, pluggy-1.6.0
rootdir: /home/vedth/stuhdy/transformers
configfile: pyproject.toml
plugins: timeout-2.4.0, rich-0.2.0, xdist-3.8.0, rerunfailures-15.1, anyio-4.11.0, order-1.3.0, asyncio-1.3.0, hypothesis-6.148.3
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function
collected 1 item
tests/models/qwen3_omni_moe/test_processing_qwen3_omni_moe.py::Qwen3OmniMoeProcessorTest::test_apply_chat_template_video_frame_sampling PASSED [100%]
So I don't really know. @zucchini-nlp, any reason why this could happen?
The failures on qwen3 aren't related; that probably happened due to too many requests, which caused a failure to download image data. I will take a look at this PR this week; I have been quite busy with other tasks lately.
In the meantime, can you make sure the PR is ready and has the final bench script and performance results in the description? Also, I think it would be nice to allow users to initialize a cache with a max batch size from the model.generate() call, so we can pass the param over to the cache init when generating. That way we can also test whether the feature aligns well with auto-compile in the generation loop.
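The user-facing call could then look roughly like this (a sketch that assumes the existing cache_config plumbing in generate() and that the new parameter keeps the name max_batch_size; input_ids is just a placeholder batch of prompts):
# allocate the static cache once with a larger capacity than the running batch,
# so the same cache object can be reused across generate() calls with smaller batches
out = model.generate(
    input_ids,
    cache_implementation="static",
    cache_config={"max_batch_size": 64, "max_cache_len": 256},
)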
Btw, can you also do a speed benchmark? Based on the logs, it seems this creates an issue with cudagraphs; I am not sure, though, whether it is critical:
shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim)) and attn_output = torch.nn.functional.scaled_dot_product_attention( # transformers/integrations/sdpa_attention.py:96 in sdpa_attention_forward (_dynamo/utils.py:3421 in run_node)
V1208 17:05:18.623000 743734 torch/_dynamo/guards.py:3508] [0/2] [__recompiles] - 0/0: tensor
'fn.__self__.past_key_values.cross_attention_cache.layers[0].keys' size mismatch at index 0. expected 128, actual 1
I1208 17:05:21.047000 743734 torch/_inductor/cudagraph_trees.py:390] [__cudagraphs] recording cudagraph tree for graph without symints
V1208 17:05:21.048000 743734 torch/_inductor/cudagraph_trees.py:2256] [__cudagraphs] Running warmup of function 9
V1208 17:05:21.055000 743734 torch/_inductor/cudagraph_trees.py:2213] [__cudagraphs] Recording function 9 of graph recording id 9
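For the speed benchmark, one thing worth checking (a sketch, not something this PR changes, reusing the names from the test script above) is whether hinting dynamo that the batch dimension is dynamic avoids these per-batch-size recompiles and cudagraph re-recordings; the log above suggests torch._dynamo.mark_unbacked for the same reason.
# hint that dim 0 (batch) of the decode inputs is dynamic before the compiled call
torch._dynamo.mark_dynamic(cur_token, 0)
torch._dynamo.mark_dynamic(encoder_outputs, 0)
cur_token, past_key_values_out = run_encoder(model, cur_token, encoder_outputs, past_key_values_out, prefill=False)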