
HH scores summed along batch dimension

Open · yeoedward opened this issue on Dec 20, 2023 · 4 comments

The hh scores seem to be summed along the batch dimension, which is strange as they are sequence-dependent. Shouldn't separate hh scores be maintained for each sequence in a batch?

Code: https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L132
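
For illustration, a minimal sketch of the two behaviors (not the repository's exact code), assuming attention weights of shape (bsz, num_heads, q_len, kv_len):

import torch

# hypothetical shapes: 2 sequences, 4 heads, 8 query positions, 8 kv positions
attn = torch.rand(2, 4, 8, 8)

# summing over the batch (dim 0) as well as the query dim mixes scores from
# different sequences into a single (num_heads, kv_len) table shared by the batch
shared_scores = attn.sum(0).sum(1)   # (4, 8)

# keeping the batch dimension gives each sequence its own score table
per_seq_scores = attn.sum(2)         # (2, 4, 8)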

Also, thanks for open sourcing your code!

yeoedward · Dec 20, 2023

@yeoedward @Ying1123 @Kyriection Hi, is there an answer to the question above? I would also like to know: when batched inference is used with Llama, how should hh_score be updated?

ChuanhongLi · Jan 4, 2024

Hi, the HH scores should be maintained independently for each sequence. In this implementation we use a single sequence per batch for testing. We will update the implementation to support multiple sequences shortly, by modifying (https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L269)
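
One way to make the accumulator batch-aware (a hypothetical sketch, not the repository's implementation) is to keep an explicit batch dimension in the score tensor:

import torch

def update_hh_score(hh_score, attn_score_cache):
    # attn_score_cache: (bsz, num_heads, q_len, kv_len) attention weights
    num_new_tokens = attn_score_cache.shape[2]
    # sum only over the query dimension, keeping one table per sequence
    new_scores = attn_score_cache.sum(2)  # (bsz, num_heads, kv_len)
    if hh_score is None:
        return new_scores
    # accumulate previous scores onto the positions that already existed
    new_scores[:, :, :-num_new_tokens] += hh_score
    return new_scores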

Kyriection · Jan 4, 2024

Hi, the HH scores should be maintained independently for each sequence. In this implementation we use a single sequence per batch for testing. We will update the implementation to support multiple sequences shortly, by modifying (https://github.com/FMInference/H2O/blob/main/h2o_hf/utils_real_drop/modify_llama.py#L269)

@Kyriection Thanks for your reply. I have changed the code to support batched inference, as follows. With recent_size = 100 and hh_size = 24, it works well for batch size = 1. However, when the batch size is set to 2, the output is garbled (once the sequence length exceeds 124, i.e. 100 + 24). Is something wrong with the changed code?

import torch

class H2OKVCache_LayerWise:
    def __init__(
            self,
            hh_size=24,
            recent_size=1000,
            k_seq_dim=2,
            v_seq_dim=2,
    ):
        print(f"H2OKVCache-LayerWise: {hh_size}, {recent_size}")
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None  # list of per-sequence score tensors, one per batch entry

    def __call__(self, past_key_values, attn_score_cache):

        self._update_hh_score(attn_score_cache)
        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        # seq_len:  116
        # past_key_values[0]:  torch.Size([2, 52, 116, 128])
        # hh-selection: for each sequence, keep hh_size heavy hitters
        # plus the recent_size most recent tokens
        bsz, num_heads, _, head_dim = past_key_values[0].shape
        k_hh_recent = None
        v_hh_recent = None
        for i in range(0, bsz):
            # heavy hitters are picked only among the non-recent positions
            select_hh_scores = self.hh_score[i][:, :seq_len - self.recent_size]
            _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
            keep_topk = keep_topk.sort().values
            # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
            keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(
                keep_topk.shape[0], 1)

            keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

            # boolean mask over (num_heads, seq_len) marking positions to keep
            mask = torch.zeros(self.hh_score[i].shape, dtype=torch.bool).to(past_key_values[0].device)
            mask = mask.scatter(-1, keep_idx, 1)

            # NOTE: this turns out to be the bug; see the follow-up comment below
            k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
            v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

            if k_hh_recent is None:
                k_hh_recent = k_hh_recent1
                v_hh_recent = v_hh_recent1
            else:
                k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=2)
                v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=2)

            # keep only the surviving positions' scores for this sequence
            self.hh_score[i] = self.hh_score[i][mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        # attn_score_cache: (bsz, num_heads, q_len, kv_len) attention weights
        num_new_tokens = attn_score_cache.shape[2]
        temp_hh_score = []
        if self.hh_score is None:
            # first call: sum over the query dimension, one tensor per sequence
            for i in range(0, len(attn_score_cache)):
                temp_hh_score.append(attn_score_cache[i].sum(1))
            self.hh_score = temp_hh_score
        else:
            for i in range(0, len(attn_score_cache)):
                temp_score_cache = attn_score_cache[i].sum(1)
                # accumulate previous scores onto the positions that already existed
                temp_score_cache[:, :-num_new_tokens] += self.hh_score[i]
                self.hh_score[i] = temp_score_cache

    def _clean_scores(self):
        self.hh_score = None

ChuanhongLi · Jan 5, 2024

I think I see it now.

# k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
# v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
# each per-sequence slice holds a single sequence, so reshape with a batch of 1 ...
k_hh_recent1 = past_key_values[0][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
v_hh_recent1 = past_key_values[1][i].squeeze()[mask].view(1, num_heads, -1, head_dim)
if k_hh_recent is None:
    k_hh_recent = k_hh_recent1
    v_hh_recent = v_hh_recent1
else:
    # ... and concatenate along the batch dimension (dim=0), not the sequence dimension
    k_hh_recent = torch.cat([k_hh_recent, k_hh_recent1], dim=0)
    v_hh_recent = torch.cat([v_hh_recent, v_hh_recent1], dim=0)

Just updating the construction of k_hh_recent and v_hh_recent as above makes the code work.
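
As a design note, the Python loop over the batch can also be avoided entirely with torch.gather. A loop-free sketch (hypothetical function, not the repository's code), assuming hh_score is stored as a single tensor of shape (bsz, num_heads, kv_len) instead of a per-sequence list:

import torch

def compact_kv(past_key_values, hh_score, hh_size, recent_size):
    # Select hh_size heavy hitters plus recent_size recent tokens per
    # sequence and per head, without a Python loop over the batch.
    k, v = past_key_values
    bsz, num_heads, seq_len, head_dim = k.shape
    if seq_len <= hh_size + recent_size:
        return past_key_values, hh_score

    # Heavy hitters are chosen among the non-recent positions only.
    select_scores = hh_score[:, :, :seq_len - recent_size]      # (bsz, heads, seq - recent)
    _, keep_topk = torch.topk(select_scores, hh_size, dim=-1)   # (bsz, heads, hh_size)
    keep_topk = keep_topk.sort(dim=-1).values

    # Recent positions are kept unconditionally for every sequence and head.
    keep_recent = torch.arange(seq_len - recent_size, seq_len, device=k.device)
    keep_recent = keep_recent.view(1, 1, -1).expand(bsz, num_heads, -1)

    keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)      # (bsz, heads, cache_size)
    gather_idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)

    # Gather the kept positions along the sequence dimension (dim 2).
    k_new = k.gather(2, gather_idx)
    v_new = v.gather(2, gather_idx)
    new_score = hh_score.gather(2, keep_idx)
    return (k_new, v_new), new_score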

ChuanhongLi · Jan 5, 2024