bert4keras icon indicating copy to clipboard operation
bert4keras copied to clipboard

在beam search生成的时候如何输出top-k的生成结果?

Open RefluxNing opened this issue 4 years ago • 3 comments

RefluxNing avatar Sep 17 '21 02:09 RefluxNing

可以通过自行修改beam_search的源码输出。

bojone avatar Sep 26 '21 14:09 bojone

这是我写的一个支持top_k、温度参数、重复惩罚、批量解码的beam search,其中batch_encoder_out的shape是[batch_size, max_seq_len, vocab_size],可以参考下

def beam_search(self, batch_encoder_out, top_k: int = 1, temperature: float = 1.,
                repetition_penalty: float = 2.) -> Generator:
    """Batched beam-search decoding (modeled after bert4keras.snippets.AutoRegressiveDecoder).

    Args:
        batch_encoder_out: encoder outputs for the whole batch; assumed shape
            [batch_size, max_seq_len, hidden] -- TODO confirm with caller.
        top_k: beam size (number of hypotheses kept per batch item).
        temperature: exponent-based smoothing of the token distribution
            (values > 1 flatten it, values < 1 sharpen it).
        repetition_penalty: CTRL-style multiplicative penalty applied to the
            logits of tokens already present in a hypothesis.

    Yields:
        (ids, scores) after each decoding step: the current best hypothesis
        per batch item and its accumulated score. Stops yielding once every
        batch item's best hypothesis contains the end token.
    """
    batch_size = batch_encoder_out.shape[0]
    vocab_size = self.tokenizer.vocab_size

    # One row per (batch item, beam); every beam starts from the same seed ids.
    batch_output_ids = np.repeat(self.first_decoder_output_ids, batch_size * top_k, axis=0)
    batch_output_scores = np.zeros((batch_size * top_k, 1))
    # Duplicate each encoder output top_k times so all beams decode in parallel.
    batch_encoder_out = np.repeat(batch_encoder_out, top_k, axis=0)

    for step in range(self.max_decoder_length):
        logger.debug(f"beam search step {step}")
        # Next-token logits for every beam: (batch_size * top_k, vocab_size)
        batch_step_logits = self.model.decoder_last.predict([batch_encoder_out, batch_output_ids])

        # Repetition penalty: build a per-token multiplier that damps tokens
        # each hypothesis has already generated (CTRL-style).
        rep_penalty = np.zeros_like(batch_step_logits)
        # Mark (row, token) pairs for every token already emitted by that row.
        row = [i for i in range(batch_output_ids.shape[0]) for _ in batch_output_ids[i, :]]
        col = [e for row in batch_output_ids for e in row]
        rep_penalty[row, col] = 1.
        # Penalized positive logits shrink (multiplier 1/penalty); penalized
        # negative logits grow more negative (multiplier = penalty); every
        # untouched token keeps a neutral multiplier of 1.
        rep_penalty[(batch_step_logits > 0) * (rep_penalty == 1)] /= repetition_penalty
        rep_penalty[(batch_step_logits < 0) * (rep_penalty == 1)] *= repetition_penalty
        rep_penalty[rep_penalty == 0.] = 1.
        batch_step_logits *= rep_penalty

        # TODO: the softmax should take masking into account, but finished
        # sequences stop emitting after EOS, so this is acceptable for now.
        batch_step_scores = np.exp(batch_step_logits) / np.sum(np.exp(batch_step_logits), axis=-1, keepdims=True)

        # Temperature smoothing: raise probabilities to 1/T and renormalize.
        batch_step_scores = np.power(batch_step_scores, 1.0 / temperature)
        batch_step_scores = batch_step_scores / batch_step_scores.sum(axis=-1, keepdims=True)

        # Accumulate beam scores: (batch_size * top_k, vocab_size).
        # NOTE(review): this adds probabilities rather than log-probabilities;
        # confirm that is intended.
        batch_step_scores += batch_output_scores.reshape((-1, 1))

        # Regroup so each batch item sees all of its beams: (batch_size, top_k * vocab_size)
        batch_step_scores = batch_step_scores.reshape((batch_size, -1))
        # On the first step all beams are identical, so restrict to the first
        # beam's scores to pick top_k distinct tokens instead of top_1 repeated.
        if step == 0:
            batch_step_scores = batch_step_scores[:, :vocab_size]

        # Unordered top_k flat indices per batch item (argpartition, not argsort).
        batch_indices = batch_step_scores.reshape((batch_size, -1)).argpartition(-top_k, axis=1)[:, -top_k:]

        # Decompose each flat index into the source beam row (offset into the
        # flattened batch*top_k axis) and the chosen vocab column.
        batch_row_indices = (batch_indices // vocab_size + top_k * np.ones_like(batch_indices) * np.arange(
            batch_size).reshape((-1, 1))).reshape((-1, 1))
        batch_col_indices = (batch_indices % vocab_size).reshape((-1, 1))

        # Extend the surviving beams by one token: (batch_size * top_k, step + 1)
        batch_output_ids = np.concatenate([batch_output_ids[batch_row_indices.reshape(-1)], batch_col_indices], 1)

        # Gather the scores of the surviving beams: (batch_size, top_k)
        batch_output_scores = np.take_along_axis(batch_step_scores.reshape((batch_size, -1)), batch_indices, axis=1)

        # # below the minimum length, keep generating
        # if batch_output_ids.shape[0] < self.min_decoder_length * 0 + 2:
        #     continue

        # Early termination for finished sequences.
        # Token id 1 is presumably the EOS id -- TODO confirm against tokenizer.
        ended_seq_index = np.equal(batch_output_ids, 1).sum(axis=1, keepdims=True).reshape((batch_size, -1))
        # Boolean mask of the best-scoring beam for each batch item.
        id_best_seq = batch_output_scores.max(axis=1, keepdims=True) == batch_output_scores
        ended_seq = (ended_seq_index & id_best_seq).sum(axis=1)
        # TODO: early-stop individual finished sequences; currently decoding
        # only exits once every batch item's best beam has finished.
        if np.all(ended_seq):
            return

        # Yield the best hypothesis per batch item plus its current score.
        yield batch_output_ids[id_best_seq.reshape((batch_size * top_k))], batch_output_scores.max(axis=1)

i4never avatar Sep 28 '21 07:09 i4never

这是我写的一个支持top_k、温度参数、重复惩罚、批量解码的beam search,其中batch_encoder_out的shape是[batch_size, max_seq_len, vocab_size],可以参考下

def beam_search(self, batch_encoder_out, top_k: int = 1, temperature: float = 1.,
                repetition_penalty: float = 2.) -> Generator:
    """Batched beam-search decoding (modeled after bert4keras.snippets.AutoRegressiveDecoder).

    Args:
        batch_encoder_out: encoder outputs for the whole batch; assumed shape
            [batch_size, max_seq_len, hidden] -- TODO confirm with caller.
        top_k: beam size (number of hypotheses kept per batch item).
        temperature: exponent-based smoothing of the token distribution
            (values > 1 flatten it, values < 1 sharpen it).
        repetition_penalty: CTRL-style multiplicative penalty applied to the
            logits of tokens already present in a hypothesis.

    Yields:
        (ids, scores) after each decoding step: the current best hypothesis
        per batch item and its accumulated score. Stops yielding once every
        batch item's best hypothesis contains the end token.
    """
    batch_size = batch_encoder_out.shape[0]
    vocab_size = self.tokenizer.vocab_size

    # One row per (batch item, beam); every beam starts from the same seed ids.
    batch_output_ids = np.repeat(self.first_decoder_output_ids, batch_size * top_k, axis=0)
    batch_output_scores = np.zeros((batch_size * top_k, 1))
    # Duplicate each encoder output top_k times so all beams decode in parallel.
    batch_encoder_out = np.repeat(batch_encoder_out, top_k, axis=0)

    for step in range(self.max_decoder_length):
        logger.debug(f"beam search step {step}")
        # Next-token logits for every beam: (batch_size * top_k, vocab_size)
        batch_step_logits = self.model.decoder_last.predict([batch_encoder_out, batch_output_ids])

        # Repetition penalty: build a per-token multiplier that damps tokens
        # each hypothesis has already generated (CTRL-style).
        rep_penalty = np.zeros_like(batch_step_logits)
        # Mark (row, token) pairs for every token already emitted by that row.
        row = [i for i in range(batch_output_ids.shape[0]) for _ in batch_output_ids[i, :]]
        col = [e for row in batch_output_ids for e in row]
        rep_penalty[row, col] = 1.
        # Penalized positive logits shrink (multiplier 1/penalty); penalized
        # negative logits grow more negative (multiplier = penalty); every
        # untouched token keeps a neutral multiplier of 1.
        rep_penalty[(batch_step_logits > 0) * (rep_penalty == 1)] /= repetition_penalty
        rep_penalty[(batch_step_logits < 0) * (rep_penalty == 1)] *= repetition_penalty
        rep_penalty[rep_penalty == 0.] = 1.
        batch_step_logits *= rep_penalty

        # TODO: the softmax should take masking into account, but finished
        # sequences stop emitting after EOS, so this is acceptable for now.
        batch_step_scores = np.exp(batch_step_logits) / np.sum(np.exp(batch_step_logits), axis=-1, keepdims=True)

        # Temperature smoothing: raise probabilities to 1/T and renormalize.
        batch_step_scores = np.power(batch_step_scores, 1.0 / temperature)
        batch_step_scores = batch_step_scores / batch_step_scores.sum(axis=-1, keepdims=True)

        # Accumulate beam scores: (batch_size * top_k, vocab_size).
        # NOTE(review): this adds probabilities rather than log-probabilities;
        # confirm that is intended.
        batch_step_scores += batch_output_scores.reshape((-1, 1))

        # Regroup so each batch item sees all of its beams: (batch_size, top_k * vocab_size)
        batch_step_scores = batch_step_scores.reshape((batch_size, -1))
        # On the first step all beams are identical, so restrict to the first
        # beam's scores to pick top_k distinct tokens instead of top_1 repeated.
        if step == 0:
            batch_step_scores = batch_step_scores[:, :vocab_size]

        # Unordered top_k flat indices per batch item (argpartition, not argsort).
        batch_indices = batch_step_scores.reshape((batch_size, -1)).argpartition(-top_k, axis=1)[:, -top_k:]

        # Decompose each flat index into the source beam row (offset into the
        # flattened batch*top_k axis) and the chosen vocab column.
        batch_row_indices = (batch_indices // vocab_size + top_k * np.ones_like(batch_indices) * np.arange(
            batch_size).reshape((-1, 1))).reshape((-1, 1))
        batch_col_indices = (batch_indices % vocab_size).reshape((-1, 1))

        # Extend the surviving beams by one token: (batch_size * top_k, step + 1)
        batch_output_ids = np.concatenate([batch_output_ids[batch_row_indices.reshape(-1)], batch_col_indices], 1)

        # Gather the scores of the surviving beams: (batch_size, top_k)
        batch_output_scores = np.take_along_axis(batch_step_scores.reshape((batch_size, -1)), batch_indices, axis=1)

        # # below the minimum length, keep generating
        # if batch_output_ids.shape[0] < self.min_decoder_length * 0 + 2:
        #     continue

        # Early termination for finished sequences.
        # Token id 1 is presumably the EOS id -- TODO confirm against tokenizer.
        ended_seq_index = np.equal(batch_output_ids, 1).sum(axis=1, keepdims=True).reshape((batch_size, -1))
        # Boolean mask of the best-scoring beam for each batch item.
        id_best_seq = batch_output_scores.max(axis=1, keepdims=True) == batch_output_scores
        ended_seq = (ended_seq_index & id_best_seq).sum(axis=1)
        # TODO: early-stop individual finished sequences; currently decoding
        # only exits once every batch item's best beam has finished.
        if np.all(ended_seq):
            return

        # Yield the best hypothesis per batch item plus its current score.
        yield batch_output_ids[id_best_seq.reshape((batch_size * top_k))], batch_output_scores.max(axis=1)

有相对完整的代码吗

SunYanCN avatar Sep 30 '21 03:09 SunYanCN

@SunYanCN 没有:)

i4never avatar Sep 30 '21 04:09 i4never