llm_interview_note
llm_interview_note copied to clipboard
MHA_MQA_GQA代码问题
在attention的计算时的注意力拼接部分。output的维度是想从[batch_size, num_heads, seq_len, head_dim]变为[batch_size, seq_len, model_dim]。因此感觉transpose有误。应该从
## 对注意力输出进行拼接
output = (
output.transpose(-1, -2)
.contiguous()
.view(batch_size, -1, self.head_dim * self.num_heads)
)
改为
## 对注意力输出进行拼接
output = (
output.transpose(-2, -3)
.contiguous()
.view(batch_size, -1, self.head_dim * self.num_heads)
)