Why does the KV cache implementation concatenate a zeros matrix to xq?
This is the part of the Attention method that implements the KV cache:
past_key, past_value = past_kv
xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
xk = torch.cat((past_key, self.wk(current_token)), dim=1)
xv = torch.cat((past_value, self.wv(current_token)), dim=1)
Why does xq need a zeros matrix concatenated to it? Is it just to keep xq's seqlen dimension the same as xk and xv? That adds extra computation, and even if the dimensions differ, the subsequent operations should still run correctly. For example, llama3's implementation does not concatenate a zeros matrix:
In llama3, seqlen is 1 at inference time: its generate function feeds only current_token into the attention layer at each step.
# llama3 attention
def forward(
    self,
    x: torch.Tensor,
    start_pos: int,
    freqs_cis: torch.Tensor,
    mask: Optional[torch.Tensor],
):
    # seqlen == 1 at inference time
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    # up to this point, xq, xk, xv all have shape [bsz, 1, *, self.head_dim]
    # the freqs_cis passed in is likewise cis[-1:, :], i.e. shape [1, head_dim]

    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    # q, k, v need consistent seqlen dimensions for the RoPE embedding and the computations before it

    self.cache_k = self.cache_k.to(xq)
    self.cache_v = self.cache_v.to(xq)

    self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
    self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

    keys = self.cache_k[:bsz, : start_pos + seqlen]
    values = self.cache_v[:bsz, : start_pos + seqlen]

    # repeat k/v heads if n_kv_heads < n_heads
    keys = repeat_kv(
        keys, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
    values = repeat_kv(
        values, self.n_rep
    )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    values = values.transpose(
        1, 2
    )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
- Suppose q_head = 16 and seqlen = 20:
  - q = (1, 16, 1, 32)
  - k = (1, 16, 20, 32)
  - v = (1, 16, 20, 32)
  - $q \times k^T + \text{mask} = (1, 16, 1, 20) + (1, 16, 20, 20) = (1, 16, 20, 20)$
In this computation it is perfectly fine to use only the current token as $q$; there is no need to concatenate the $q$ of the previous $n-1$ tokens, because at inference time the input sequence length is $seqlen = 1$.
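To make the single-token decode pattern concrete, here is a minimal, hedged sketch of a llama3-style decode loop. It is not llama3's actual generate code: `model`, `greedy_decode` and the greedy sampling are hypothetical stand-ins, and the only point is the start_pos bookkeeping around the cache.

import torch

def greedy_decode(model, prompt_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # prefill: run the whole prompt once so cache_k / cache_v are filled from position 0
    tokens = prompt_tokens                                            # (bsz, prompt_len)
    logits = model(tokens, start_pos=0)                               # attention sees seqlen == prompt_len
    for _ in range(max_new_tokens):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # (bsz, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
        # decode: feed only the new token; start_pos tells attention where it lands in the cache
        logits = model(next_token, start_pos=tokens.shape[1] - 1)     # seqlen == 1
    return tokens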
Unlike LLaMA3, this model's generate function works the same way as training: at every step it feeds the complete token sequence of length seqlen into the attention layer.
When only the current token's $xq$ is computed, the $q$ rows of the previous $n-1$ tokens have to be concatenated back so that $xq$, $xk$ and $xv$ keep the same $seqlen$ dimension through the RoPE encoding and the computations before it. Those $n-1$ $q$ rows do not actually affect the current token's $qk$ computation, so zeros_like(x[:, :-1, :]) is used in their place.
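A quick toy check of that claim (illustrative only: made-up shapes, no RoPE, mask, or projection weights) shows that zeroing out every query row except the last one leaves the current token's attention output unchanged:

import math
import torch

torch.manual_seed(0)
bsz, seqlen, head_dim = 1, 5, 8
q_full = torch.randn(bsz, seqlen, head_dim)
k = torch.randn(bsz, seqlen, head_dim)
v = torch.randn(bsz, seqlen, head_dim)

# replace all but the last query row with zeros, as the cached implementation does
q_padded = torch.cat([torch.zeros_like(q_full[:, :-1, :]), q_full[:, -1:, :]], dim=1)

def attend(q, k, v):
    scores = q @ k.transpose(1, 2) / math.sqrt(head_dim)
    return torch.softmax(scores, dim=-1) @ v

out_full = attend(q_full, k, v)
out_padded = attend(q_padded, k, v)
# only the last row matters for the current token, and it is identical in both cases
print(torch.allclose(out_full[:, -1, :], out_padded[:, -1, :]))   # True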
Indeed, the single-token approach has attention complexity $O(N \times dim_{head})$, far below the full-sequence approach's $O(N^2 \times dim_{head})$, so having generate feed only the current token at each inference step is more efficient. Also, the KV cache here was not modeled on LLaMA3; it started out as a hastily written, rough cache variable.
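As a rough per-step illustration (score computation only, per attention head), taking a hypothetical context of $N = 1024$ cached tokens and $dim_{head} = 64$: the single-token path costs about $N \times dim_{head} = 1024 \times 64 \approx 6.6 \times 10^{4}$ multiply-accumulates, versus $N^2 \times dim_{head} = 1024^2 \times 64 \approx 6.7 \times 10^{7}$ for recomputing the full sequence, i.e. roughly a $1000\times$ gap at that context length.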
I just revised it; this is a modified version that forces the low-complexity current_token path onto the existing scheme:
def forward(
    self,
    x: torch.Tensor,
    pos_cis: torch.Tensor,
    use_kv_cache: bool = False,
    past_kv: Tuple[torch.Tensor] = None
):
    bsz, seqlen, _ = x.shape
    keys, values = None, None
    # flag == 1: full-sequence path (training, no cache, or the first cached step)
    # flag == 2: incremental path (only the current token is projected)
    flag = 1
    # QKV
    if use_kv_cache:
        # inference
        current_token = x[:, -1:, :]
        if not past_kv:
            # first step: project the whole prompt and create the cache
            xq = self.wq(x)
            xk, xv = self.wk(x), self.wv(x)
            flag = 1
            past_kv = (xk, xv)
        else:
            # later steps: project only the current token and append it to the cache
            past_key, past_value = past_kv
            xq = self.wq(current_token)
            xk = self.wk(current_token)
            xv = self.wv(current_token)
            keys = torch.cat((past_key, xk), dim=1)
            values = torch.cat((past_value, xv), dim=1)
            past_kv = (keys, values)
            flag = 2
    else:
        xq = self.wq(x)
        xk, xv = self.wk(x), self.wv(x)

    if flag == 2:
        xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, 1, self.n_local_kv_heads, self.head_dim)
    else:
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    if flag == 1:
        # RoPE relative positional embeddings over the full sequence
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        if use_kv_cache:
            # cache the rotated keys so later incremental steps reuse RoPE'd entries
            past_kv = (xk.reshape(bsz, seqlen, -1), xv.reshape(bsz, seqlen, -1))
    else:
        # incremental step: rotate only the current position
        xq, xk = apply_rotary_emb(xq, xk, pos_cis[-1:, :])

    if flag == 2:
        past_key, past_value = past_kv
        # write the rotated key of the current position back into the cache
        keys = torch.cat((past_key[:, :-1, :], xk.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        values = torch.cat((past_value[:, :-1, :], xv.view(bsz, 1, self.n_local_kv_heads * self.head_dim)), dim=1)
        past_kv = (keys, values)
        keys = keys.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = values.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xk = keys
        xv = values

    # grouped multiquery attention: expand out keys and values
    xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
    xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

    # make heads into a batch dimension
    xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    # manual implementation
    scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
    assert hasattr(self, 'mask')
    scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    scores = self.attn_dropout(scores)
    output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

    # final projection into the residual stream
    output = self.wo(output)
    output = self.resid_dropout(output)
    return output, past_kv
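For completeness, a rough sketch of how a caller could thread past_kv through successive calls; the names here are hypothetical (attn stands for this Attention module, append_next_position is an imagined helper that extends the hidden states and pos_cis by one position), and this is not the repository's actual generate loop:

# hypothetical calling pattern (sketch only)
past_kv = None
# first call: the full prompt goes through the flag == 1 path and creates the cache
h, past_kv = attn(x, pos_cis, use_kv_cache=True, past_kv=past_kv)
for _ in range(max_new_tokens):
    # x keeps growing by one position; only x[:, -1:, :] is projected on the flag == 2 path
    x, pos_cis = append_next_position(x, pos_cis)   # imagined helper, not real code
    h, past_kv = attn(x, pos_cis, use_kv_cache=True, past_kv=past_kv)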
For now I won't make any major changes to the inference function.
Feel free to keep the discussion going and point out anything I got wrong.
In the latest code, has the self-attention module dropped the KV cache computation and kept only the saving of it?
No, it has not been dropped.
Be clear that during inference, when kv_cache is in use, the input x has sequence length 1:
if first_seq or not use_cache:
    out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
else:
    out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
               start_pos=input_ids.shape[1] - 1, **args)