PainlessInferenceAcceleration
PainlessInferenceAcceleration copied to clipboard
请问:batch_indices 是什么含义? LookaheadCache 的 put、stream_put 函数的最后一个参数 idx 是什么含义? class Tree 里面的 idx 代表什么含义?
请问:batch_indices 是什么含义? LookaheadCache 的 put、stream_put 函数的最后一个参数 idx 是什么含义? mode 和 idx 之间有关系吗? 为什么 mode='input' 时 idx 取 0, 1, 2, 3, … 这样的 batch 下标, 而 mode='output' 时 idx = -1?
https://github.com/alipay/PainlessInferenceAcceleration/blob/8015f12f7fe32acc102bb3eb51c4f8b3a420e79c/pia/lookahead/common/pretrained_model_batch.py#L1254-L1259
def put(self, token_ids, branch_length=8, final=False, mode='output', idx=0):
为什么注释里说 "idx is only used for caching output_ids"?
def stream_put(self, token_ids, branch_length=8, final=False, mode='output', idx=0):
# idx is only used for caching output_ids
class Tree 里面的idx 代表什么意思? https://github.com/alipay/PainlessInferenceAcceleration/blob/8015f12f7fe32acc102bb3eb51c4f8b3a420e79c/pia/lookahead/common/lookahead_cache.py#L24-L63
idx 有哪些取值?这些取值分别代表什么含义? 跟多batch的batch下标有什么关系吗?
请问 class Tree 的 get() 函数的 idx 参数以及返回值 sizes(含两个元素的列表, 初始为 [0, 0])分别是什么含义?
assert mode in ('input', 'output', 'mix') 这三种模式是什么意思?
def get(self, token_ids, max_size=64, max_length=8, min_input_size=0,
        min_output_size=0, output_weight=1e-4, mode='mix', idx=0):
    """Return candidate token ids, a square mask and a size pair for ``token_ids``.

    NOTE(review): this snippet was pasted without indentation; the nesting
    below is reconstructed from the syntax. In particular the final ``else``
    of the 'mix' branch is assumed to pair with ``if size > max_size`` (as in
    the 'input' and 'output' branches) -- confirm against the repository.

    Flow:
      1. ``self._match`` looks up ``token_ids``; when it yields no nodes a
         single-token fallback is returned immediately.
      2. ``self._dfs_get_freqs`` fills ``freqs``; each entry is indexed as
         ``x[0]..x[3]`` below (presumably node index, input frequency,
         output frequency and mixed frequency -- TODO confirm).
      3. Depending on ``mode``, minimum-frequency thresholds are derived.
      4. ``self._ravel`` receives ``ids``, ``mask`` and ``sizes`` together
         with the thresholds (presumably filling them in place -- its body
         is not visible here).

    Args:
        token_ids: sequence handed to ``_match``; its last element is the
            fallback token when nothing matches and the sequence is non-empty.
        max_size: side length of the mask allocated before trimming, and the
            candidate count above which thresholds are computed.
        max_length: forwarded unchanged to ``_ravel``.
        min_input_size: 1-based rank into the input-sorted ``freqs`` used to
            pick ``min_input_freq``; in 'mix' mode also the number of top
            input entries reserved.
        min_output_size: same as ``min_input_size`` but for output frequency.
        output_weight: forwarded to ``_dfs_get_freqs`` and ``_ravel``. It is
            overwritten with 0.0 ('input') or 1.0 ('output') only after
            ``_dfs_get_freqs`` has run, so only ``_ravel`` sees the override.
        mode: one of 'input', 'output', 'mix'.
        idx: forwarded to ``_match``, ``_dfs_get_freqs`` and ``_ravel``.

    Returns:
        ``(ids, mask, sizes)``: ``ids`` is a list starting with the matched
        token id (or ``self.token_id``); ``mask`` is an int64 array trimmed
        to ``(len(ids), len(ids))``; ``sizes`` is the two-element list that
        was initialised to ``[0, 0]`` and passed to ``_ravel``.

    Raises:
        AssertionError: if ``mode`` is not one of the three accepted values
            (only when assertions are enabled).
    """
    assert mode in ('input', 'output', 'mix')
    match_token_id, nodes = self._match(token_ids, mode=mode, idx=idx)
    if len(nodes) == 0:
        # Nothing matched: return one token, a 1x1 mask of ones and zeroed sizes.
        token_id = token_ids[-1] if len(token_ids) > 0 else self.token_id
        return [token_id], np.ones((1, 1), dtype=np.int64), [0, 0]
    freqs = []
    self._dfs_get_freqs(nodes, freqs, idx, output_weight)
    # self._bfs_get_freqs(nodes, freqs, idx, output_weight)
    # Thresholds start at a large sentinel (1e9) and are lowered per mode below.
    min_mix_freq = 1e9
    min_input_freq = 1e9
    min_output_freq = 1e9
    if mode == 'input':
        output_weight = 0.0
        # Count entries with a positive x[1].
        size = len([x for x in freqs if x[1] > 0])
        if size > max_size:
            input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
            # NOTE(review): with the default min_input_size=0 this indexes
            # [-1], i.e. the last entry of the descending sort -- confirm
            # that this is intended.
            min_input_freq = input_freqs[min_input_size - 1][1]
        else:
            # Few enough candidates: no frequency cut-off.
            min_input_freq = 0.0
    elif mode == 'output':
        output_weight = 1.0
        # Count entries with a positive x[2].
        size = len([x for x in freqs if x[2] > 0])
        if size > max_size:
            output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True)
            # NOTE(review): same [-1] indexing as above when
            # min_output_size is 0 -- confirm.
            min_output_freq = output_freqs[min_output_size - 1][2]
        else:
            min_output_freq = 0.0
    else:
        # 'mix' mode: count entries positive in either x[1] or x[2].
        size = len([x for x in freqs if x[1] > 0 or x[2] > 0])
        if size > max_size:
            # x[0] values of the entries selected so far.
            indices = set()
            if min_input_size > 0:
                # Reserve the top min_input_size entries ranked by x[1].
                input_freqs = sorted(freqs, key=lambda x: x[1], reverse=True)
                min_input_freq = input_freqs[min_input_size - 1][1]
                indices.update([x[0] for x in input_freqs[:min_input_size]])
            if min_output_size > 0:
                # Reserve the top min_output_size entries ranked by x[2].
                output_freqs = sorted(freqs, key=lambda x: x[2], reverse=True)
                min_output_freq = output_freqs[min_output_size - 1][2]
                indices.update([x[0] for x in output_freqs[:min_output_size]])
            if len(indices) < max_size:
                # Fill the remaining slots from the ranking by x[3].
                mix_freqs = sorted(freqs, key=lambda x: x[3], reverse=True)
                rest_size = max_size - len(indices)
                indices.update([x[0] for x in mix_freqs[:rest_size]])
                cur_size = len(indices)
                # Continue down the x[3] ranking, skipping entries already
                # selected; the x[3] value at which cur_size reaches max_size
                # becomes the mix threshold.
                for i in range(rest_size, min(rest_size + max_size, size)):
                    if mix_freqs[i][0] in indices:
                        continue
                    cur_size += 1
                    if cur_size >= max_size:
                        x = mix_freqs[i]
                        min_mix_freq = x[3]
                        break
        else:
            min_mix_freq = 0.0
    mask = np.zeros((max_size, max_size), dtype=np.int64)
    # Column 0 is set for every row (presumably so each candidate sees the
    # first token -- TODO confirm against _ravel).
    mask[:, 0] = 1
    # NOTE(review): `or` also falls back to self.token_id when
    # match_token_id is 0, not only when it is None -- confirm.
    ids = [match_token_id or self.token_id]
    sizes = [0, 0]
    self._ravel(nodes, ids, mask, -1,
                max_size=max_size,
                max_length=max_length,
                min_output_freq=min_output_freq,
                min_input_freq=min_input_freq,
                min_mix_freq=min_mix_freq,
                sizes=sizes,
                output_weight=output_weight,
                mode=mode,
                idx=idx)
    # Trim the pre-allocated mask to the number of ids present after _ravel.
    size = len(ids)
    mask = mask[:size, :size]
    return ids, mask, sizes