
MHA_MQA_GQA code issue

Open ZhuJiaqi9905 opened this issue 1 year ago • 0 comments

In the attention computation, at the step where the per-head outputs are concatenated, `output` is meant to be reshaped from `[batch_size, num_heads, seq_len, head_dim]` to `[batch_size, seq_len, model_dim]`. The `transpose` call therefore looks wrong. It should be changed from

```python
# concatenate the attention outputs from all heads
output = (
    output.transpose(-1, -2)
    .contiguous()
    .view(batch_size, -1, self.head_dim * self.num_heads)
)
```

to

```python
# concatenate the attention outputs from all heads
output = (
    output.transpose(-2, -3)
    .contiguous()
    .view(batch_size, -1, self.head_dim * self.num_heads)
)
```
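
A minimal standalone shape check (with illustrative values for `batch_size`, `num_heads`, `seq_len`, and `head_dim`, and a plain tensor instead of the repo's attention class) shows why the second version is the intended one: `transpose(-1, -2)` swaps `seq_len` and `head_dim`, so the subsequent `view` produces a tensor of the right shape but with heads and positions interleaved incorrectly, while `transpose(-2, -3)` (equivalently `transpose(1, 2)` for a 4-D tensor) moves the head dimension next to `head_dim` so the heads are concatenated along the last axis.

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64

# simulated attention output: [batch_size, num_heads, seq_len, head_dim]
output = torch.randn(batch_size, num_heads, seq_len, head_dim)

# original code: swaps seq_len and head_dim -> [batch, heads, head_dim, seq_len];
# the view still yields [2, 16, 512], but head and position data are scrambled
wrong = output.transpose(-1, -2).contiguous().view(batch_size, -1, head_dim * num_heads)
print(wrong.shape)  # torch.Size([2, 16, 512])

# proposed fix: swaps num_heads and seq_len -> [batch, seq_len, heads, head_dim],
# which is the layout needed to concatenate heads along the last dimension
right = output.transpose(-2, -3).contiguous().view(batch_size, -1, head_dim * num_heads)
print(right.shape)  # torch.Size([2, 16, 512])

# the more common equivalent idiom for a 4-D tensor
right2 = output.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)
assert torch.equal(right, right2)
```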

ZhuJiaqi9905 · Aug 19 '24