H2O
HH scores summed along batch dimension
The hh scores seem to be summed along the batch dimension, which is strange as they are sequence-dependent. Shouldn't separate hh scores be maintained for each sequence in a batch?
Code: https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L132
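To illustrate the difference (a rough sketch with assumed shapes, not code taken from the repo): summing over the batch dimension produces a single score tensor shared by all sequences, whereas keeping the batch dimension gives each sequence its own scores.

import torch

# assumed attention-weight shape: [bsz, num_heads, q_len, kv_len]
attn = torch.rand(2, 32, 16, 16)

# batch-summed accumulation (the behavior being asked about):
# one [num_heads, kv_len] score tensor shared by every sequence in the batch
shared_hh_score = attn.sum(0).sum(1)

# per-sequence accumulation: a [bsz, num_heads, kv_len] tensor,
# so heavy hitters can differ between sequences in the same batch
per_seq_hh_score = attn.sum(2)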
Also, thanks for open sourcing your code!
@yeoedward @Ying1123 @Kyriection Hi, is there an answer to the above question? Also, I would like to know how to update the hh_score when batched inference is used with LLaMA.
Hi, the HH scores should be sequence-independent. In this implementation we use one sequence per batch for testing. We will update the implementation to support multiple sequences shortly by modifying (https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L269)
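For reference, one way the update at that line could be made batch-aware (just a sketch of the idea, not the repo's actual implementation): keep hh_score with an explicit batch dimension and sum only over the query positions.

def _update_hh_score(self, attn_score_cache):
    # attn_score_cache assumed to have shape [bsz, num_heads, q_len, kv_len]
    num_new_tokens = attn_score_cache.shape[2]
    # sum over query positions only, keeping the batch dimension,
    # so each sequence accumulates its own heavy-hitter scores
    scores = attn_score_cache.sum(2)  # [bsz, num_heads, kv_len]
    if self.hh_score is None:
        self.hh_score = scores
    else:
        scores[..., :-num_new_tokens] += self.hh_score
        self.hh_score = scores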
@Kyriection Thanks for your reply. I have changed the code to support batched inference, as shown below. With recent_size = 100 and hh_size = 24 it works well for batch size = 1. However, when the batch size is set to 2, the output is garbled (once the sequence length exceeds 124 (100 + 24)). Is something wrong with the changed code?
class H2OKVCache_LayerWise:
    def __init__(
        self,
        hh_size=24,
        recent_size=1000,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        print(f"H2OKVCache-LayerWise: {hh_size}, {recent_size}")
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)
        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        # seq_len: 116
        # past_key_values[0]: torch.Size([2, 52, 116, 128])
        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape
        k_hh_recent = None
        v_hh_recent = None
        for i in range(0, bsz):
            select_hh_scores = self.hh_score[i][:, :seq_len - self.recent_size]
            _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
            keep_topk = keep_topk.sort().values
            # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
            keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(
                keep_topk.shape[0], 1)
            keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
            mask = torch.zeros(self.hh_score[i].shape, dtype=torch.bool).to(past_key_values[0].device)
            mask = mask.scatter(-1, keep_idx, 1)
            k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            if k_hh_recent is None:
                k_hh_recent = k_hh_recent1
                v_hh_recent = v_hh_recent1
            else:
                k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=2)
                v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=2)
            self.hh_score[i] = self.hh_score[i][mask].view(num_heads, self.cache_size)
        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        num_new_tokens = attn_score_cache.shape[2]
        temp_hh_score = []
        if self.hh_score is None:
            for i in range(0, len(attn_score_cache)):
                temp_hh_score.append(attn_score_cache[i].sum(1))
            self.hh_score = temp_hh_score
        else:
            for i in range(0, len(attn_score_cache)):
                temp_score_cache = attn_score_cache[i].sum(1)
                temp_score_cache[:, :-num_new_tokens] += self.hh_score[i]
                self.hh_score[i] = temp_score_cache

    def _clean_scores(self):
        self.hh_score = None
I think I see the issue.
# k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
# v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
# print("line 52 k_hh_recent: ", k_hh_recent1.shape)
# print("line 53 v_hh_recent: ", v_hh_recent1.shape)
if k_hh_recent is None:
    k_hh_recent = k_hh_recent1
    v_hh_recent = v_hh_recent1
else:
    k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=0)
    v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=0)
Just updating the construction of k_hh_recent and v_hh_recent fixes it: each per-sequence slice should be reshaped with a leading batch dimension of 1 (not bsz), and the per-sequence results should be concatenated along the batch dimension (dim=0) rather than the sequence dimension (dim=2). With that change, the code works.
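Putting the two changes together, the per-sequence selection loop inside __call__ would look roughly like this (a sketch of the fix described above, assuming every sequence in the batch shares the same seq_len):

k_hh_recent, v_hh_recent = None, None
for i in range(bsz):
    select_hh_scores = self.hh_score[i][:, :seq_len - self.recent_size]
    _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
    keep_topk = keep_topk.sort().values
    keep_recent = torch.arange(seq_len - self.recent_size, seq_len,
                               device=keep_topk.device).repeat(keep_topk.shape[0], 1)
    keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
    mask = torch.zeros(self.hh_score[i].shape, dtype=torch.bool,
                       device=past_key_values[0].device)
    mask = mask.scatter(-1, keep_idx, 1)
    # keep a leading batch dimension of 1 for this single sequence ...
    k_i = past_key_values[0][i][mask].view(1, num_heads, -1, head_dim)
    v_i = past_key_values[1][i][mask].view(1, num_heads, -1, head_dim)
    # ... and stack the per-sequence results along the batch dimension (dim=0)
    if k_hh_recent is None:
        k_hh_recent, v_hh_recent = k_i, v_i
    else:
        k_hh_recent = torch.cat([k_hh_recent, k_i], dim=0)
        v_hh_recent = torch.cat([v_hh_recent, v_i], dim=0)
    self.hh_score[i] = self.hh_score[i][mask].view(num_heads, self.cache_size)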