
Bloom's k and v caches are both (bs, head, seq, head_dim) in newer versions of transformers

Open · shadow150519 opened this issue 4 months ago · 0 comments

Hi, thanks for your awesome demo of speculative sampling. Some of your code may be outdated with newer versions of transformers. The KVCacheModel class assumes Bloom's k cache has shape [bs * head, head_dim, seq] and its v cache has shape [bs * head, seq, head_dim], but in transformers 4.44.2 both cache tensors have shape (bs, head, seq, head_dim), so this function no longer works.
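A small sketch of why the shape change breaks the old slicing (not code from the repo; numpy arrays stand in for torch tensors here, since the slicing semantics are identical):

```python
import numpy as np

# Under transformers >= 4.44, Bloom's caches use the unified layout
# (batch, num_heads, seq_len, head_dim).
bs, heads, seq, head_dim = 1, 8, 10, 16
end_pos = 6
k = np.zeros((bs, heads, seq, head_dim))
v = np.zeros((bs, heads, seq, head_dim))

# The old Bloom k slice k[:, :, :end_pos] happens to still hit the
# seq axis (axis 2) in the new layout, so it trims correctly:
assert k[:, :, :end_pos].shape == (1, 8, 6, 16)

# But the old Bloom v slice v[:, :end_pos, :] now trims the num_heads
# axis instead of seq_len, silently corrupting the cache:
assert v[:, :end_pos, :].shape == (1, 6, 10, 16)

# The correct slice in the new layout trims axis 2 for both tensors:
assert v[:, :, :end_pos, :].shape == (1, 8, 6, 16)
```

So the k branch survives the layout change by accident, while the v branch is the one that actually misbehaves.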

    @torch.no_grad()
    def rollback(self, end_pos: int):
        past_key_values_trimmed = []
        assert self._past_key_values
        for kv in self._past_key_values:
            k, v = kv
            # NOTE() the indexing is specific for bloom. This won't work for other models
            # For example llama k, v should be (batch, num_head, seq_len, hidden_dim)
            
            # Bloom is special one
            if isinstance(self._model, BloomForCausalLM):
                # k (batch * head, hidden_dim, seq); v (batch * head, seq, hidden_dim)
                k = k[:, :, :end_pos]
                v = v[:, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)
            else:
                # k, v (batch, head, seq, hidden_dim)
                k = k[:, :, :end_pos, :]
                v = v[:, :, :end_pos, :]
                kv_trimmed = (k, v)
                past_key_values_trimmed.append(kv_trimmed)
        
        self._past_key_values = past_key_values_trimmed
        self._prob_history = self._prob_history[:, :end_pos, :]
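A hypothetical sketch of the fix for transformers >= 4.44: since Bloom's caches now use the same (batch, num_heads, seq_len, head_dim) layout as other models, the Bloom special case could simply be dropped and both tensors trimmed on the seq axis. The helper name `rollback_kv` is mine, and numpy arrays again stand in for torch tensors:

```python
import numpy as np

def rollback_kv(past_key_values, end_pos):
    # With the unified layout, seq_len is axis 2 for both k and v,
    # so one slice works for every model, Bloom included.
    trimmed = []
    for k, v in past_key_values:
        trimmed.append((k[:, :, :end_pos, :], v[:, :, :end_pos, :]))
    return trimmed

# one layer's cache, layout (batch, num_heads, seq_len, head_dim)
cache = [(np.zeros((1, 8, 10, 16)), np.zeros((1, 8, 10, 16)))]
k_t, v_t = rollback_kv(cache, 6)[0]
assert k_t.shape == (1, 8, 6, 16)
assert v_t.shape == (1, 8, 6, 16)
```

If the repo still needs to support older transformers versions, the Bloom branch could instead be guarded by checking the cache tensor's dimensionality (e.g. `k.dim() == 3` for the legacy fused layout) rather than the model class alone.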

Here is my debug information: [screenshot attached]

shadow150519 avatar Oct 06 '24 05:10 shadow150519