bert4keras
bert4keras copied to clipboard
在beam search生成的时候如何输出top-k的生成结果?
可以通过自行修改beam_search的源码输出。
这是我写的一个支持top_k、温度参数、重复惩罚的批量解码beam search,其中batch_encoder_out的shape是[batch_size, max_seq_len, vocab_size],可以参考下
def beam_search(self, batch_encoder_out, top_k: int = 1, temperature: float = 1.,
                repetition_penalty: float = 2.) -> Generator:
    """Batched beam-search decoding.

    Modeled on bert4keras.snippets.AutoRegressiveDecoder.

    Args:
        batch_encoder_out: encoder output batch; only ``shape[0]`` (the batch
            size) is read directly here — the rest of the shape is consumed by
            ``self.model.decoder_last``. NOTE(review): the surrounding post
            says the shape is (batch_size, max_seq_len, vocab_size) — confirm
            against the encoder.
        top_k: beam size (number of hypotheses kept per sample).
        temperature: smoothing exponent applied to the token probabilities
            (``p ** (1/temperature)``, renormalised).
        repetition_penalty: CTRL-style penalty factor applied to the logits of
            tokens already present in a beam's output.

    Yields:
        ``(best_ids, best_scores)`` at each step: the current best hypothesis
        per sample (token-id rows) and its accumulated score.
    """
    batch_size = batch_encoder_out.shape[0]
    vocab_size = self.tokenizer.vocab_size
    # One row per beam: (batch_size * top_k, seed_len) decoder inputs, plus a
    # running score per beam initialised to zero.
    batch_output_ids = np.repeat(self.first_decoder_output_ids, batch_size * top_k, axis=0)
    batch_output_scores = np.zeros((batch_size * top_k, 1))
    # Tile the encoder output so every beam of a sample sees that sample's
    # encoding (rows are grouped sample-by-sample, top_k rows each).
    batch_encoder_out = np.repeat(batch_encoder_out, top_k, axis=0)
    for step in range(self.max_decoder_length):
        logger.debug(f"beam search step {step}")
        # Next-token logits for every beam: (batch_size * top_k, vocab_size).
        batch_step_logits = self.model.decoder_last.predict([batch_encoder_out, batch_output_ids])
        # Repetition penalty: mark every (beam, token) pair where the token
        # was already generated by that beam.
        rep_penalty = np.zeros_like(batch_step_logits)
        row = [i for i in range(batch_output_ids.shape[0]) for _ in batch_output_ids[i, :]]
        col = [e for row in batch_output_ids for e in row]
        rep_penalty[row, col] = 1.
        # CTRL-style: shrink positive logits (factor 1/penalty), amplify
        # negative ones (factor penalty); untouched positions get factor 1.
        rep_penalty[(batch_step_logits > 0) * (rep_penalty == 1)] /= repetition_penalty
        rep_penalty[(batch_step_logits < 0) * (rep_penalty == 1)] *= repetition_penalty
        rep_penalty[rep_penalty == 0.] = 1.
        batch_step_logits *= rep_penalty
        # TODO: the softmax should take masking into account; since finished
        # sequences stop emitting tokens after eos, this is deferred for now.
        batch_step_scores = np.exp(batch_step_logits) / np.sum(np.exp(batch_step_logits), axis=-1, keepdims=True)
        # Temperature smoothing: renormalised p ** (1/T).
        batch_step_scores = np.power(batch_step_scores, 1.0 / temperature)
        batch_step_scores = batch_step_scores / batch_step_scores.sum(axis=-1, keepdims=True)
        # Accumulate beam scores: (batch_size * top_k, vocab_size).
        # NOTE(review): raw probabilities are summed here rather than log
        # probabilities — verify this scoring is intended.
        batch_step_scores += batch_output_scores.reshape((-1, 1))
        # (batch_size, top_k * vocab_size)
        batch_step_scores = batch_step_scores.reshape((batch_size, -1))
        # First step: all top_k beams of a sample are identical, so restrict
        # the candidate pool to one beam's vocab slice — this yields top_k
        # distinct first tokens instead of the same top-1 repeated.
        if step == 0:
            batch_step_scores = batch_step_scores[:, :vocab_size]
        batch_indices = batch_step_scores.reshape((batch_size, -1)).argpartition(-top_k, axis=1)[:, -top_k:]
        # Decompose the flat candidate indices into a source-beam row (offset
        # by each sample's beam-group start) and a vocab column.
        batch_row_indices = (batch_indices // vocab_size + top_k * np.ones_like(batch_indices) * np.arange(
            batch_size).reshape((-1, 1))).reshape((-1, 1))
        batch_col_indices = (batch_indices % vocab_size).reshape((-1, 1))
        # Re-assemble the surviving top_k token ids: (batch_size * top_k, step + 1).
        batch_output_ids = np.concatenate([batch_output_ids[batch_row_indices.reshape(-1)], batch_col_indices], 1)
        # Re-assemble the surviving top_k scores.
        batch_output_scores = np.take_along_axis(batch_step_scores.reshape((batch_size, -1)), batch_indices, axis=1)
        # # (disabled) keep generating until a minimum length is reached:
        # # if batch_output_ids.shape[0] < self.min_decoder_length * 0 + 2:
        # #     continue
        # Early termination: token id 1 is treated as end-of-sequence —
        # NOTE(review): confirm the eos id against the tokenizer.
        ended_seq_index = np.equal(batch_output_ids, 1).sum(axis=1, keepdims=True).reshape((batch_size, -1))
        id_best_seq = batch_output_scores.max(axis=1, keepdims=True) == batch_output_scores
        ended_seq = (ended_seq_index & id_best_seq).sum(axis=1)
        # TODO: early-stop finished sequences individually; currently decoding
        # only exits once every sample's best beam has ended.
        if np.all(ended_seq):
            return
        yield batch_output_ids[id_best_seq.reshape((batch_size * top_k))], batch_output_scores.max(axis=1)
这是我写的一个支持top_k、温度参数、重复惩罚的批量解码beam search,其中batch_encoder_out的shape是[batch_size, max_seq_len, vocab_size],可以参考下
# NOTE(review): the two lines below are a whitespace-mangled duplicate of the
# beam_search implementation pasted earlier in this thread — the copy lost its
# original line breaks and indentation, so this text is not valid Python as
# written. Refer to the properly line-broken copy above for the readable code.
def beam_search(self, batch_encoder_out, top_k: int = 1, temperature: float = 1., repetition_penalty: float = 2.) -> Generator: """ 批量beam search解码 参考bert4keras.snippets.AutoRegressiveDecoder Args: batch_encoder_out: top_k: beam_size temperature: token prob平滑 repetition_penalty: 重复惩罚 Returns: """ batch_size = batch_encoder_out.shape[0] vocab_size = self.tokenizer.vocab_size batch_output_ids = np.repeat(self.first_decoder_output_ids, batch_size * top_k, axis=0) batch_output_scores = np.zeros((batch_size * top_k, 1)) batch_encoder_out = np.repeat(batch_encoder_out, top_k, axis=0) for step in range(self.max_decoder_length): logger.debug(f"beam search step {step}") # 计算下一个token prob (batch_size * top_k, vocab_size) batch_step_logits = self.model.decoder_last.predict([batch_encoder_out, batch_output_ids]) # 重复惩罚 rep_penalty = np.zeros_like(batch_step_logits) row = [i for i in range(batch_output_ids.shape[0]) for _ in batch_output_ids[i, :]] col = [e for row in batch_output_ids for e in row] rep_penalty[row, col] = 1. rep_penalty[(batch_step_logits > 0) * (rep_penalty == 1)] /= repetition_penalty rep_penalty[(batch_step_logits < 0) * (rep_penalty == 1)] *= repetition_penalty rep_penalty[rep_penalty == 0.] = 1. 
batch_step_logits *= rep_penalty # TODO: 计算softmax,需要考虑mask,但是终止序列看见eos后停止输出后续token,暂时可以不处理 batch_step_scores = np.exp(batch_step_logits) / np.sum(np.exp(batch_step_logits), axis=-1, keepdims=True) # temperature 平滑 batch_step_scores = np.power(batch_step_scores, 1.0 / temperature) batch_step_scores = batch_step_scores / batch_step_scores.sum(axis=-1, keepdims=True) # (batch_size * top_k, vocab_size) batch_step_scores += batch_output_scores.reshape((-1, 1)) # (batch_size, top_k * vocab_size) batch_step_scores = batch_step_scores.reshape((batch_size, -1)) # 如果是起始步骤,从每个encoder_out的第一行取即可(确保第一步是top_k而非top_1) if step == 0: batch_step_scores = batch_step_scores[:, :vocab_size] batch_indices = batch_step_scores.reshape((batch_size, -1)).argpartition(-top_k, axis=1)[:, -top_k:] batch_row_indices = (batch_indices // vocab_size + top_k * np.ones_like(batch_indices) * np.arange( batch_size).reshape((-1, 1))).reshape((-1, 1)) batch_col_indices = (batch_indices % vocab_size).reshape((-1, 1)) # 调整 top_k token_ids (batch_size * top_k, step + 1) batch_output_ids = np.concatenate([batch_output_ids[batch_row_indices.reshape(-1)], batch_col_indices], 1) # 调整 top_k scores (batch_size * top_k, step + 1) batch_output_scores = np.take_along_axis(batch_step_scores.reshape((batch_size, -1)), batch_indices, axis=1) # # 不到最短长度,继续生成 # if batch_output_ids.shape[0] < self.min_decoder_length * 0 + 2: # continue # 生成完成的seq提前终止 # 终止序列 ended_seq_index = np.equal(batch_output_ids, 1).sum(axis=1, keepdims=True).reshape((batch_size, -1)) id_best_seq = batch_output_scores.max(axis=1, keepdims=True) == batch_output_scores ended_seq = (ended_seq_index & id_best_seq).sum(axis=1) # TODO: early stopping已完成的序列,现在是所有序列已完成退出 if np.all(ended_seq): return # batch_output_ids[id_best_seq.reshape((batch_size * top_k))] yield batch_output_ids[id_best_seq.reshape((batch_size * top_k))], batch_output_scores.max(axis=1)
有相对完整的代码吗
@SunYanCN 没有:)