
Why does the KV cache implementation concatenate a zeros matrix onto xq?

Open Zbaoli opened this issue 1 year ago • 1 comments

In the Attention method, the part that implements the KV cache:

past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)

Why does xq need a zeros matrix concatenated to it? Is it to keep xq's seqlen dimension the same as xk and xv? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run fine; LLaMA3's implementation, for example, does not concatenate a zeros matrix:

Zbaoli avatar Sep 13 '24 10:09 Zbaoli

In the Attention method, the part that implements the KV cache:

past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)

Why does xq need a zeros matrix concatenated to it? Is it to keep xq's seqlen dimension the same as xk and xv? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run fine; LLaMA3's implementation, for example, does not concatenate a zeros matrix:

During inference LLaMA3's seqlen is 1: its generate function feeds only current_token into the attention layer at each step.

# llama3 attention
def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
):
    # during inference, seqlen == 1
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    # up to this point, xq, xk, xv all have shape [bsz, 1, *, self.head_dim]
    # the freqs_cis passed in is likewise cis[-1:, :] = [1, head_dim]
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # before the RoPE embedding is computed, q, k, v need to have consistent (seqlen) dimensions

    self.cache_k = self.cache_k.to(xq)
    self.cache_v = self.cache_v.to(xq)

    self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
    self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

    keys = self.cache_k[:bsz, : start_pos + seqlen]
    values = self.cache_v[:bsz, : start_pos + seqlen]

    # repeat k/v heads if n_kv_heads < n_heads
    keys = repeat_kv(
        keys, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
    values = repeat_kv(
        values, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    values = values.transpose(
        1, 2
    )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
  • Suppose n_heads = 16, head_dim = 32, and cache_len + seqlen = 20:
    q = (1, 16, 1, 32)
    k = (1, 16, 20, 32)
    $q \times k^T = (1, 16, 1, 20)$ (when seqlen == 1 the mask passed in is None)
    v = (1, 16, 20, 32), so $\text{scores} \times v = (1, 16, 1, 32)$

In this computation, using only the current token as $q$ is fine; there is no need to concatenate the $q$ of the previous $n-1$ tokens, because during inference the input sequence length is $seqlen = 1$.

Unlike LLaMA3, the model's generate function here works like training: at every step it feeds the full token sequence of length seqlen into the attention layer.

When only the current token's $xq$ is computed, the $q$ of the previous $n-1$ tokens has to be concatenated back so that $xq$ keeps the same $seqlen$ dimension as $xk$ and $xv$ through the RoPE encoding and the computations before it. Those $n-1$ query rows have no actual effect on the current token's $qk$ computation, so torch.zeros_like(x[:, :-1, :]) is used in their place; the small check below illustrates this.
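As a quick sanity check of that claim, here is a minimal self-contained sketch (the shapes, the attend helper, and the omission of RoPE and the causal mask are simplifications of my own, not minimind's actual code) showing that zero-padding the earlier query rows leaves the last position's attention output unchanged:

import torch

torch.manual_seed(0)
bsz, seqlen, dim = 1, 5, 8
q_full = torch.randn(bsz, seqlen, dim)   # queries for every position
k = torch.randn(bsz, seqlen, dim)        # cached keys plus the current key
v = torch.randn(bsz, seqlen, dim)        # cached values plus the current value

def attend(q, k, v):
    # plain scaled dot-product attention, no RoPE and no causal mask
    scores = q @ k.transpose(1, 2) / dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v

# (1) use only the current token's query, shape (bsz, 1, dim)
out_last = attend(q_full[:, -1:, :], k, v)

# (2) zero-pad the first seqlen-1 query rows so q matches k/v in seqlen
q_padded = torch.cat((torch.zeros_like(q_full[:, :-1, :]), q_full[:, -1:, :]), dim=1)
out_padded = attend(q_padded, k, v)

# the last output row is the same either way; the zero rows only add wasted compute
print(torch.allclose(out_last[:, -1, :], out_padded[:, -1, :]))  # True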

Indeed, the attention cost of the former (current-token-only) approach is $O(N \times d_{head})$, far smaller than the latter's (full-sequence) $O(N^2 \times d_{head})$, so having generate use only the current token at each inference step is more efficient. Also, the KV cache here was not modeled on LLaMA3; originally it was just a hastily written, rough cache variable.
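For a rough sense of scale (numbers assumed purely for illustration): with $N = 1024$ tokens in the context and $d_{head} = 64$, scoring all $N$ query positions against all $N$ keys costs on the order of $N^2 \times d_{head} \approx 6.7 \times 10^7$ multiply-adds per head per step, while scoring only the current token's query costs about $N \times d_{head} \approx 6.6 \times 10^4$, an $N$-fold (here $1024\times$) saving per step.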

I just reworked it along those lines; this is a modified version that forces low-complexity, current-token-only computation onto the existing scheme:

def forward(
        self,
        x: torch.Tensor,
        pos_cis: torch.Tensor,
        use_kv_cache: bool = False,
        past_kv: Tuple[torch.Tensor] = None
):
    bsz, seqlen, _ = x.shape

    keys, values = None, None
    # flag == 1: full-sequence pass (training, or the first step with no cache yet)
    # flag == 2: cached single-token decode step
    flag = 1
    # QKV
    # inference
    if use_kv_cache:
        current_token = x[:, -1:, :]

        if not past_kv:
            xq = self.wq(x)
            xk, xv = self.wk(x), self.wv(x)
            flag = 1
            past_kv = (xk, xv)
        else:
            past_key, past_value = past_kv
            xq = self.wq(current_token)
            xk = self.wk(current_token)
            xv = self.wv(current_token)
            keys = torch.cat((past_key, xk), dim=1)
            values = torch.cat((past_value, xv), dim=1)
            past_kv = (keys, values)
            flag = 2
    else:
        xq = self.wq(x)
        xk, xv = self.wk(x), self.wv(x)

    if flag == 2:
        xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
    else:
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    if flag == 1:
        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
    else:
        xq, xk = apply_rotary_emb(xq, xk, pos_cis[-1:, :])

    if flag == 2:
        # replace the last cached position with the RoPE-rotated key (values are simply re-appended),
        # then use the full-length keys/values for this step's attention
        past_key, past_value = past_kv
        keys = torch.cat((past_key[:, :-1, :], xk.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        values = torch.cat((past_value[:, :-1, :], xv.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        past_kv = (keys, values)
        keys = keys.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = values.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xk = keys
        xv = values

    # grouped multiquery attention: expand out keys and values
    xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
    xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

    # make heads into a batch dimension
    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    # manual implementation
    scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
    assert hasattr(self, 'mask')
    # note: when flag == 2, scores is (bs, n_local_heads, 1, seqlen); adding the
    # (1, 1, seqlen, seqlen) mask broadcasts it to (bs, n_local_heads, seqlen, seqlen)
    scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    scores = self.attn_dropout(scores)
    output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

    # final projection into the residual stream
    output = self.wo(output)
    output = self.resid_dropout(output)
    return output, past_kv
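As a side note on the pos_cis[-1:, :] slice in the flag == 2 branch above: rotary embeddings rotate each position independently, so rotating only the newest position with its own frequency row gives exactly the last row of rotating the whole sequence. A small self-contained check (precompute_pos_cis and apply_rope below are my own llama-style stand-ins, not minimind's actual helpers):

import torch

def precompute_pos_cis(head_dim, end, theta=10000.0):
    # complex rotation factors, shape (end, head_dim // 2)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(end).float()
    return torch.polar(torch.ones(end, head_dim // 2), torch.outer(t, freqs))

def apply_rope(xq, xk, pos_cis):
    # xq, xk: (bsz, seqlen, n_heads, head_dim); pos_cis: (seqlen, head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = pos_cis[None, :, None, :]  # broadcast over batch and heads
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

bsz, seqlen, n_heads, head_dim = 1, 6, 2, 8
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
pos_cis = precompute_pos_cis(head_dim, seqlen)

q_all, k_all = apply_rope(xq, xk, pos_cis)                       # rotate the whole sequence
q_one, k_one = apply_rope(xq[:, -1:], xk[:, -1:], pos_cis[-1:])  # rotate only the last position

print(torch.allclose(q_all[:, -1:], q_one), torch.allclose(k_all[:, -1:], k_one))  # True True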

I won't make major changes to the inference function for now.

Further discussion and corrections are welcome.

jingyaogong avatar Sep 13 '24 13:09 jingyaogong

In the latest code, has the KV cache been dropped from the self-attention computation and only stored?

[screenshot attached]

GaoYangGF avatar Apr 08 '25 04:04 GaoYangGF

In the latest code, has the KV cache been dropped from the self-attention computation and only stored?

[screenshot attached]

No, it has not been removed.

During inference, when the kv_cache is present, the input x has a sequence length of 1; that is the key thing to keep in mind.

if first_seq or not use_cache:
    out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
else:
    out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
               start_pos=input_ids.shape[1] - 1, **args)
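A tiny illustration of that shape difference (the sizes here are assumed purely for the example):

import torch

input_ids = torch.randint(0, 6400, (1, 12))   # pretend 12 tokens have been generated so far
print(input_ids.shape)           # torch.Size([1, 12]) -> the first (uncached) call sees the full sequence
print(input_ids[:, -1:].shape)   # torch.Size([1, 1])  -> each cached decode step feeds only the newest token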

jingyaogong avatar Apr 08 '25 05:04 jingyaogong