LLMSpeculativeSampling
Bloom's k and v caches are both (bs, head, seq, head_dim) in newer versions of transformers
Hi, thanks for your awesome demo of speculative sampling. Some of the code may be outdated with newer versions of transformers.
In the KVCacheModel class, the Bloom model's k cache is assumed to have shape (bs * head, head_dim, seq) and its v cache shape (bs * head, seq, head_dim). But in transformers 4.44.2, both the k and v caches have shape (bs, head, seq, head_dim), so the rollback function below no longer works.
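For reference, the new layout can be confirmed with a short check. This is just a sketch, assuming transformers 4.44.2; bigscience/bloom-560m is only an illustrative checkpoint:

    import torch
    from transformers import AutoTokenizer, BloomForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, use_cache=True)

    # Layer 0 key/value pair from the cache.
    k, v = out.past_key_values[0]
    print(k.shape)  # (bs, head, seq, head_dim) in transformers 4.44.2
    print(v.shape)  # same layout; the old (bs * head, head_dim, seq) k shape is gone

Here is the affected function: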
    @torch.no_grad()
    def rollback(self, end_pos : int):
        past_key_values_trimmed = []
        assert self._past_key_values
        for kv in self._past_key_values:
            k, v = kv
            # NOTE() the indexing is specific for bloom. This won't work for other models
            # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)
            # Bloom is special one
            if isinstance(self._model, BloomForCausalLM):
                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
                k = k[:, :, :end_pos]
                v = v[:, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)
            else:
                # k, v (batch, head, seq, hidden_dim)
                k = k[:, :, :end_pos, :]
                v = v[:, :, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)

        self._past_key_values = past_key_values_trimmed
        self._prob_history = self._prob_history[:, :end_pos, :]
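If that holds, one possible fix is simply to drop the Bloom special case and trim every model's cache with the generic branch. A minimal sketch, assuming all cached tensors now follow the (batch, head, seq, head_dim) layout (I have only checked Bloom on 4.44.2):

    @torch.no_grad()
    def rollback(self, end_pos : int):
        past_key_values_trimmed = []
        assert self._past_key_values
        for k, v in self._past_key_values:
            # In transformers 4.44.2, Bloom follows the standard layout too,
            # so both k and v can be trimmed as (batch, head, seq, head_dim).
            k = k[:, :, :end_pos, :]
            v = v[:, :, :end_pos, :]
            past_key_values_trimmed.append((k, v))
        self._past_key_values = past_key_values_trimmed
        self._prob_history = self._prob_history[:, :end_pos, :]

Alternatively, the isinstance check could be kept and gated on the installed transformers version, but treating all models uniformly seems simpler if only newer versions need to be supported.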
Here is my debug information: