No GQA support for Llama in real_drop
In modify_llama.py, the hh_score of H2OCache is computed by attn_scores.sum(0).sum(1), resulting in a shape of [num_heads, kv_seq_len]. However, in Llama's GQA implementation (in the same file), the k/v cache has a shape of [B, num_key_value_heads, ...], which mismatches hh_score in the head dimension.
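A minimal illustration of the mismatch (the dimensions below are made up for illustration; 32 query heads sharing 8 KV heads, Llama-3-style):

import torch

## Hypothetical dims: batch 1, 32 query heads, 8 KV heads, 16 queries, 512 cached keys.
B, num_heads, num_key_value_heads, Q, K, head_dim = 1, 32, 8, 16, 512, 128

attn_scores = torch.rand(B, num_heads, Q, K)
hh_score = attn_scores.sum(0).sum(1)                         ## [32, 512]
key_cache = torch.rand(B, num_key_value_heads, K, head_dim)  ## [1, 8, 512, 128]
## hh_score has 32 rows of per-token scores, but the cache has only 8 KV heads,
## so eviction indices derived from hh_score cannot index key_cache directly.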
I manually implemented it with the eager torch attention kernel. Whether I "keep the first head of each group", "take the mean of each group", "take the sum of each group", or repeat the key/value states in the cache, H2O gives me unacceptable results like:
I almost copied these lines verbatim, so it does not seem to be an implementation mistake.
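For reference, the three group reductions listed above look roughly like this (a sketch continuing the tensors from the snippet above, with num_kv_groups = num_heads // num_key_value_heads):

## Consecutive query heads share one KV head in Llama's GQA, so grouping
## adjacent rows of hh_score lines it up with the [num_key_value_heads, K] cache.
num_kv_groups = num_heads // num_key_value_heads
grouped = hh_score.view(num_key_value_heads, num_kv_groups, -1)
first_of_group = grouped[:, 0]        ## keep the first head of each group
mean_of_group = grouped.mean(dim=1)   ## mean of each group
sum_of_group = grouped.sum(dim=1)     ## sum of each group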
Here is the implementation; DynamicCache is the same as transformers' DynamicCache.
import torch

class H2OCache(DynamicCache):
    ## Inheriting from DynamicCache is not a bug; it matches the official code.
    def __init__(self, max_len, device, num_key_value_heads, num_kv_groups,
                 hh_size=128,
                 recent_size=512):
        super().__init__(max_len, device)
        self.num_key_value_heads = num_key_value_heads
        self.num_kv_groups = num_kv_groups
        self.recent_size = recent_size
        self.hh_size = hh_size
        self.hh_scores = []

    def _update_hh_scores(self, attn_score_cache, layer_idx):
        ## Check https://github.com/FMInference/H2O/blob/281ffef3f1432ceb1a6899362d2f20e1ef13aa94/h2o_hf/utils_real_drop/modify_llama.py#L290
        num_new_tokens = attn_score_cache.shape[2]
        if len(self.hh_scores) <= layer_idx:
            ## First call for this layer: [B, H, Q, K] -> [H, K].
            hh_score = attn_score_cache.sum(0).sum(1)
            self.hh_scores.append(hh_score)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)  ## [B, H, Q, K] -> [H, Q, K] -> [H, K]
            ## Accumulate the scores of tokens that were already cached.
            attn_score_cache[:, :-num_new_tokens] += self.hh_scores[layer_idx]  ## [H, K]
            ## Can't work with GQA: H here is num_heads, not num_key_value_heads.
            self.hh_scores[layer_idx] = attn_score_cache

    def evict(self, attn_scores, layer_idx):
        """
        attn_scores: [B, H, Q, K]
        """
        ## Not quite the same as the paper: a top-k selection is used here.
        self._update_hh_scores(attn_scores, layer_idx)
        bsz, num_heads, q_len, k_len = attn_scores.shape  ## k_len = cache len (after store)
        seq_len = k_len
        if k_len <= self.recent_size + self.hh_size:
            return
        head_dim = self.key_cache[layer_idx].shape[-1]
        ## Heavy hitters are selected only outside the recent window.
        select_hh_scores = self.hh_scores[layer_idx][:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
        hh_score = self.hh_scores[layer_idx]
        mask = torch.zeros_like(hh_score, dtype=torch.bool, device=self.key_cache[layer_idx].device)
        mask = mask.scatter(-1, keep_idx, True)
        ## squeeze() assumes bsz == 1; the boolean mask keeps the selected tokens per head.
        k_hh_recent = self.key_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = self.value_cache[layer_idx].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        self.hh_scores[layer_idx] = hh_score[mask].view(num_heads, self.hh_size + self.recent_size)
        self.key_cache[layer_idx] = k_hh_recent
        self.value_cache[layer_idx] = v_hh_recent
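For completeness, this is roughly how evict is driven from the eager attention forward; the wrapper function and its signature are mine for illustration, not the repo's exact code:

import math
import torch

def eager_attn_with_h2o(query_states, key_states, value_states, causal_mask, cache, layer_idx):
    ## query_states: [B, num_heads, Q, D]; key/value_states already expanded to num_heads.
    head_dim = query_states.shape[-1]
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = attn_weights + causal_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    ## Hand the raw attention map to the cache so it can accumulate hh_scores and evict.
    cache.evict(attn_weights.detach(), layer_idx)
    return attn_output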
Supplementary:
Well, it turned out to be a bug hiding in the rotary embedding (which is not shown here, so I didn't find it at first). Re-applying the rotary embedding at each step is necessary, so I modified the whole DynamicCache class.
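A minimal sketch of what the fix amounts to (the helper name is mine, and the calling convention assumes recent transformers versions of LlamaRotaryEmbedding / apply_rotary_pos_emb): keep keys un-rotated in the cache and re-apply rotary with contiguous positions every step, so eviction cannot desynchronize positions.

import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

def rotate_cached_keys(key_cache, rotary_emb):
    ## key_cache holds UN-rotated keys, [B, H_kv, S, D]; positions are re-assigned
    ## contiguously after eviction, which is the point of the fix.
    seq_len = key_cache.shape[-2]
    position_ids = torch.arange(seq_len, device=key_cache.device).unsqueeze(0)
    cos, sin = rotary_emb(key_cache, position_ids)
    ## apply_rotary_pos_emb rotates q and k together; pass keys twice and keep one.
    _, keys_rotated = apply_rotary_pos_emb(key_cache, key_cache, cos, sin)
    return keys_rotated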
Still, the performance is not outstanding, as shown below:
Dedicated GQA support is needed. (Mine is a 'repeat_interleave' implementation: I manually repeated the KV cache. It is closest to the paper but the least efficient; see the sketch below.)
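A sketch of that workaround (variable names are illustrative): expand the GQA cache to one KV head per query head before attention, so hh_scores and the cache agree on the head dimension.

## Equivalent to transformers' repeat_kv: each KV head is repeated
## num_kv_groups times consecutively along the head dimension.
key_states = cache.key_cache[layer_idx].repeat_interleave(num_kv_groups, dim=1)
value_states = cache.value_cache[layer_idx].repeat_interleave(num_kv_groups, dim=1)
## Both are now [B, num_heads, S, D], so evict() needs no per-group reduction.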