
No support for GQA of Llama in real_drop

Tomorrowdawn opened this issue 1 year ago • 1 comment

In modify_llama.py, the hh_score of H2OCache is computed by attn_scores.sum(0).sum(1), resulting in a shape of [num_heads, seq_len]. However, in Llama's GQA implementation (in the same file), the k/v cache has a shape of [B, num_key_value_heads, ...], which does not match the hh_score along the head dimension.
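To make the mismatch concrete, here is a minimal sketch (the head counts are assumed for illustration: 32 query heads sharing 8 KV heads):

import torch

## assumed GQA config for illustration
B, num_heads, num_key_value_heads, head_dim = 1, 32, 8, 128
q_len = k_len = 16

attn_scores = torch.rand(B, num_heads, q_len, k_len)
hh_score = attn_scores.sum(0).sum(1)  ## [num_heads, k_len] = [32, 16]

## the GQA cache only stores num_key_value_heads heads
key_cache = torch.rand(B, num_key_value_heads, k_len, head_dim)  ## [1, 8, 16, 128]

## any eviction mask built from hh_score (32 rows) cannot index key_cache (8 rows)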

I manually implemented it with an eager torch attention kernel. Whether I keep the first head of each group, take the mean of each group, take the sum of each group, or repeat the key/value states in the cache, H2O gives me an unacceptable result like:

[screenshot: garbled generation output]

I copied these lines almost verbatim, so it does not seem to be an implementation mistake.
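For concreteness, the group reductions mentioned above look roughly like this (a sketch; scores is [B, num_heads, Q, K] and the heads of a group are assumed contiguous, as in transformers' repeat_kv):

grouped = scores.view(bsz, num_key_value_heads, num_kv_groups, q_len, k_len)
first = grouped[:, :, 0]    ## keep the first head of each group
mean = grouped.mean(dim=2)  ## mean of each group
total = grouped.sum(dim=2)  ## sum of each group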

Here is the implementation; DynamicCache is the same as transformers' DynamicCache.

import torch

class H2OCache(DynamicCache):
    ## inheriting from DynamicCache is not a bug;
    ## it matches the official code.
    def __init__(self, max_len, device, num_key_value_heads, num_kv_groups,
                 hh_size=128,
                 recent_size=512):
        super().__init__(max_len, device)
        self.num_key_value_heads = num_key_value_heads
        self.num_kv_groups = num_kv_groups
        self.recent_size = recent_size
        self.hh_size = hh_size
        self.hh_scores = []  ## per-layer accumulated attention scores, each [num_heads, cache_len]
    def _update_hh_scores(self, attn_score_cache, layer_idx):
        ## see https://github.com/FMInference/H2O/blob/281ffef3f1432ceb1a6899362d2f20e1ef13aa94/h2o_hf/utils_real_drop/modify_llama.py#L290
        num_new_tokens = attn_score_cache.shape[2]
        if len(self.hh_scores) <= layer_idx:
            ## first call for this layer: [B, H, Q, K] -> [H, K]
            hh_score = attn_score_cache.sum(0).sum(1)
            self.hh_scores.append(hh_score)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)  ## [B, H, Q, K] -> [H, Q, K] -> [H, K]
            attn_score_cache[:, :-num_new_tokens] += self.hh_scores[layer_idx]  ## accumulate over steps, [H, K]
            ## H == num_heads here, so this can't work with GQA, where the cache has num_key_value_heads
            self.hh_scores[layer_idx] = attn_score_cache

    def evict(self, attn_scores, layer_idx):
        """
        attn_scores: [B, H, Q, K]
        """
        ## not quite the same as the paper:
        ## a top-k selection is used here instead.
        self._update_hh_scores(attn_scores, layer_idx)

        bsz, num_heads, q_len, k_len = attn_scores.shape  ## k_len = cache len (after store)
        seq_len = k_len
        if k_len <= self.recent_size + self.hh_size:
            return
        head_dim = self.key_cache[layer_idx].shape[-1]

        ## pick the hh_size heavy hitters among the non-recent positions
        select_hh_scores = self.hh_scores[layer_idx][:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        ## always keep the most recent recent_size positions
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
        hh_score = self.hh_scores[layer_idx]
        mask = torch.zeros(hh_score.shape, dtype=torch.bool, device=self.key_cache[layer_idx].device)
        mask = mask.scatter(-1, keep_idx, 1)

        ## squeeze() assumes bsz == 1
        k_hh_recent = self.key_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = self.value_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_scores[layer_idx] = hh_score[mask].view(num_heads, self.hh_size + self.recent_size)

        self.key_cache[layer_idx] = k_hh_recent
        self.value_cache[layer_idx] = v_hh_recent
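For what it's worth, a GQA-aware variant would have to fold the query heads of each group together before accumulating, so that hh_scores has shape [num_key_value_heads, K] and lines up with the cache. A minimal, hypothetical sketch of such a drop-in method (mean over groups shown; sum or first-head are one-line changes), not part of the official repo:

    def _update_hh_scores_gqa(self, attn_score_cache, layer_idx):
        ## hypothetical GQA-aware variant: reduce per-query-head scores
        ## to one score per KV head so they match the cache's head count
        bsz, num_heads, q_len, k_len = attn_score_cache.shape
        num_new_tokens = q_len
        ## [B, H, Q, K] -> [B, n_kv, groups, Q, K] -> mean over groups -> [n_kv, K]
        scores = attn_score_cache.view(bsz, self.num_key_value_heads, self.num_kv_groups, q_len, k_len)
        scores = scores.mean(2).sum(0).sum(1)
        if len(self.hh_scores) <= layer_idx:
            self.hh_scores.append(scores)
        else:
            scores[:, :-num_new_tokens] += self.hh_scores[layer_idx]
            self.hh_scores[layer_idx] = scores

evict would then build its mask and index the cache with num_key_value_heads in place of num_heads.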

Tomorrowdawn · May 27 '24 13:05

Supplementary:

Well, that was a bug hiding in the rotary embedding (which is not in this snippet, so I didn't find it at first). Re-applying the rotary embedding at each step is necessary, so I modified the whole DynamicCache class.
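The reason, as far as I can tell: cached keys carry the rotary rotation of their original absolute positions, and eviction shifts every surviving token's index, so the stored rotations go stale. A minimal sketch of the re-application, assuming keys are cached pre-RoPE and cos/sin are rebuilt for the compacted positions (rotate_half as in transformers' Llama code):

import torch

def rotate_half(x):
    ## standard Llama helper: maps (x1, x2) -> (-x2, x1) on the last dim
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rerotate_keys(key_cache, cos, sin):
    ## key_cache: [B, n_kv, K, head_dim], stored WITHOUT rotary applied
    ## cos, sin: [K, head_dim], built from positions 0..K-1 of the
    ## current (post-eviction) cache so indices stay consistent
    return key_cache * cos + rotate_half(key_cache) * sin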

Still, the performance is not outstanding, as shown below:

[screenshot: evaluation results]

Dedicated support for GQA is needed (the numbers above are from the 'repeat_interleave' implementation: I manually repeated the kv cache. It is closest to the paper but least efficient).
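For reference, the repeat trick just expands the cached KV heads back to the full query-head count before attention and scoring, e.g. (shapes as assumed above):

## [B, n_kv, K, d] -> [B, n_kv * num_kv_groups, K, d]; matches transformers'
## repeat_kv layout, but moves num_kv_groups times more data per step
keys_full = key_cache.repeat_interleave(num_kv_groups, dim=1)
values_full = value_cache.repeat_interleave(num_kv_groups, dim=1)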

Tomorrowdawn · May 27 '24 14:05