PainlessInferenceAcceleration

Question: What does batch_indices mean? What does the last parameter, idx, of LookaheadCache's put and stream_put functions mean? What does idx in class Tree() represent?

Open handsome-chips opened this issue 6 months ago • 0 comments

Question: What does batch_indices mean? What does the last parameter, idx, of LookaheadCache's put and stream_put functions mean? Are mode and idx related? Why is idx set to the batch values 0, 1, 2, 3, ... when mode='input', but to -1 when mode='output'?

https://github.com/alipay/PainlessInferenceAcceleration/blob/8015f12f7fe32acc102bb3eb51c4f8b3a420e79c/pia/lookahead/common/pretrained_model_batch.py#L1254-L1259

def put(self, token_ids, branch_length=8, final=False, mode='output', idx=0):

Why is idx only used for caching output_ids?

def stream_put(self, token_ids, branch_length=8, final=False, mode='output', idx=0):
    # idx is only used for caching output_ids

What does idx in class Tree mean? https://github.com/alipay/PainlessInferenceAcceleration/blob/8015f12f7fe32acc102bb3eb51c4f8b3a420e79c/pia/lookahead/common/lookahead_cache.py#L24-L63

What values can idx take, and what does each value mean? How does it relate to the batch index in multi-batch inference?

Also, in class Tree()'s get() function, what do the idx parameter and the returned sizes = [, ] mean? And what do the three modes in assert mode in ('input', 'output', 'mix') stand for?

    def get(self, token_ids, max_size=64, max_length=8, min_input_size=0,
            min_output_size=0, output_weight=1e-4, mode='mix', idx=0):
        assert mode in ('input', 'output', 'mix')

        match_token_id, nodes = self._match(token_ids, mode=mode, idx=idx)
        if len(nodes) == 0:
            token_id = token_ids[-1] if len(token_ids) > 0 else self.token_id
            return [token_id], np.ones((1, 1), dtype=np.int64), [0, 0]

        freqs = []
        self._dfs_get_freqs(nodes, freqs, idx, output_weight)
        # self._bfs_get_freqs(nodes, freqs, idx, output_weight)

        min_mix_freq = 1e9
        min_input_freq = 1e9
        min_output_freq = 1e9
        if mode == 'input':
            output_weight = 0.0
            size = len([x for x in freqs if x[1] > 0])
            if size > max_size:
                input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
                min_input_freq = input_freqs[min_input_size - 1][1]
            else:
                min_input_freq = 0.0
        elif mode == 'output':
            output_weight = 1.0
            size = len([x for x in freqs if x[2] > 0])
            if size > max_size:
                output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True)
                min_output_freq = output_freqs[min_output_size - 1][2]
            else:
                min_output_freq = 0.0
        else:
            size = len([x for x in freqs if x[1] > 0 or x[2] > 0])
            if size > max_size:
                indices = set()
                if min_input_size > 0:
                    input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
                    min_input_freq = input_freqs[min_input_size - 1][1]
                    indices.update([x[0] for x in input_freqs[:min_input_size]])

                if min_output_size > 0:
                    output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True)
                    min_output_freq = output_freqs[min_output_size - 1][2]
                    indices.update([x[0] for x in output_freqs[:min_output_size]])

                if len(indices) < max_size:
                    mix_freqs = sorted(freqs, key=lambda x: x[3], reverse=True)
                    rest_size = max_size - len(indices)
                    indices.update([x[0] for x in mix_freqs[:rest_size]])
                    cur_size = len(indices)
                    for i in range(rest_size, min(rest_size + max_size, size)):
                        if mix_freqs[i][0] in indices:
                            continue
                        cur_size += 1
                        if cur_size >= max_size:
                            x = mix_freqs[i]
                            min_mix_freq = x[3]
                            break
            else:
                min_mix_freq = 0.0

        mask = np.zeros((max_size, max_size), dtype=np.int64)
        mask[:, 0] = 1
        ids = [match_token_id or self.token_id]
        sizes = [0, 0]
        self._ravel(nodes, ids, mask, -1,
                    max_size=max_size,
                    max_length=max_length,
                    min_output_freq=min_output_freq,
                    min_input_freq=min_input_freq,
                    min_mix_freq=min_mix_freq,
                    sizes=sizes,
                    output_weight=output_weight,
                    mode=mode,
                    idx=idx)
        size = len(ids)

        mask = mask[:size, :size]
        return ids, mask, sizes
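
To make the 'input'/'output' branch of the quoted get() easier to reason about, here is a standalone re-statement of just its threshold logic. It is a sketch, not the library's code: the tuple layout (node key, input frequency, output frequency, mix frequency) is an assumption inferred from which column each branch sorts on, and single_mode_threshold is a hypothetical helper name.

```python
from typing import List, Tuple

# Assumed layout of each freqs entry, inferred from the sort keys used in get():
# (node_key, input_freq, output_freq, mix_freq)
Freq = Tuple[int, float, float, float]


def single_mode_threshold(freqs: List[Freq], max_size: int, min_size: int,
                          column: int) -> Tuple[int, float]:
    """Hypothetical helper mirroring the mode='input' (column=1) and
    mode='output' (column=2) branches of the quoted get()."""
    # Count candidate nodes with a positive frequency in the chosen column.
    size = len([x for x in freqs if x[column] > 0])
    if size > max_size:
        # Too many candidates: the cut-off is the min_size-th largest frequency.
        ranked = sorted(freqs, key=lambda x: x[column], reverse=True)
        # Same indexing as the quoted code; min_size=0 would pick ranked[-1].
        threshold = ranked[min_size - 1][column]
    else:
        # All candidates fit, so nothing needs to be filtered out.
        threshold = 0.0
    return size, threshold
```

When more than max_size candidates have a positive count, the threshold becomes the min_size-th largest count in that column; otherwise it is 0.0. Note that with min_size=0 the expression ranked[min_size - 1] indexes the last element, exactly as min_input_size - 1 and min_output_size - 1 do in the quoted code.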

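The 'mix' branch of the quoted get() is the least obvious part, so here is the same logic lifted into a standalone function that returns only the three thresholds it computes. This is a hedged sketch for experimentation, not the library's implementation: mix_mode_threshold is a hypothetical name, and the tuple layout (node key, input frequency, output frequency, mix frequency) is assumed from the sort keys in get().

```python
from typing import List, Set, Tuple

# Assumed layout of each freqs entry: (node_key, input_freq, output_freq, mix_freq)
Freq = Tuple[int, float, float, float]


def mix_mode_threshold(freqs: List[Freq], max_size: int, min_input_size: int,
                       min_output_size: int) -> Tuple[float, float, float]:
    """Hypothetical helper mirroring the mode='mix' branch of the quoted get()."""
    # Same sentinel defaults as get(): 1e9 means "not set by this branch".
    min_input_freq = min_output_freq = min_mix_freq = 1e9
    size = len([x for x in freqs if x[1] > 0 or x[2] > 0])
    if size > max_size:
        indices: Set[int] = set()
        if min_input_size > 0:
            # Reserve the top-min_input_size nodes by input frequency.
            input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
            min_input_freq = input_freqs[min_input_size - 1][1]
            indices.update(x[0] for x in input_freqs[:min_input_size])
        if min_output_size > 0:
            # Reserve the top-min_output_size nodes by output frequency.
            output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True)
            min_output_freq = output_freqs[min_output_size - 1][2]
            indices.update(x[0] for x in output_freqs[:min_output_size])
        if len(indices) < max_size:
            # Fill the remaining budget by mixed frequency.
            mix_freqs = sorted(freqs, key=lambda x: x[3], reverse=True)
            rest_size = max_size - len(indices)
            indices.update(x[0] for x in mix_freqs[:rest_size])
            cur_size = len(indices)
            # Walk further down the mix ranking, skipping already-reserved keys,
            # until the budget is reached; that node's mix frequency is the cut-off.
            for i in range(rest_size, min(rest_size + max_size, size)):
                if mix_freqs[i][0] in indices:
                    continue
                cur_size += 1
                if cur_size >= max_size:
                    min_mix_freq = mix_freqs[i][3]
                    break
    else:
        # Everything fits, so no mix filtering is needed.
        min_mix_freq = 0.0
    return min_input_freq, min_output_freq, min_mix_freq
```

As in the quoted code, a threshold that the branch never assigns keeps its 1e9 sentinel, so inspecting which values stay at 1e9 shows which of the input, output, and mix cut-offs were actually applied for a given freqs list.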
handsome-chips, Aug 05 '24 14:08