How-to-use-Transformers

Question about the Chapter 3 attention mechanism implementation

[Open] Chaochao2020 opened this issue 1 year ago · 2 comments

In the code below, I think it should be explained why the Q, K, V vector sequences are set equal to inputs_embeds. My understanding is that in the attention mechanism, Q, K, and V are obtained by multiplying the embeddings with the three matrices W_Q, W_K, and W_V, which are themselves learnable parameters of the model, whereas the code below seems to implicitly assume these three matrices are identity matrices.

```python
import torch
from math import sqrt

Q = K = V = inputs_embeds
dim_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / sqrt(dim_k)
print(scores.size())
```
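For comparison, here is a minimal sketch of what the same computation might look like with explicit learned projections; the 1×5×768 random inputs_embeds and the projection sizes are made up purely for illustration.

```python
import torch
from torch import nn
from math import sqrt

# Illustration only: a batch of 1 sequence with 5 tokens and 768-dim embeddings
inputs_embeds = torch.randn(1, 5, 768)
embed_dim = inputs_embeds.size(-1)

# Learned projection matrices W_Q, W_K, W_V, expressed as bias-free linear layers
W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
W_K = nn.Linear(embed_dim, embed_dim, bias=False)
W_V = nn.Linear(embed_dim, embed_dim, bias=False)

Q, K, V = W_Q(inputs_embeds), W_K(inputs_embeds), W_V(inputs_embeds)
dim_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / sqrt(dim_k)
print(scores.size())  # torch.Size([1, 5, 5])
```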

In addition, dim_k = K.size(-1) is inconsistent with the wrapped function below: above it is dim_k = K.size(-1), while below it is dim_k = query.size(-1).

```python
import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if query_mask is not None and key_mask is not None:
        mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -float("inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)
```
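Continuing from the function above, a small usage sketch with made-up tensors (a batch of 2 sequences of 6 tokens, the second padded in its last two positions) shows how the masks are meant to be passed in:

```python
# Made-up example inputs; in practice these would come from the tokenizer/embedding layer
query = key = value = torch.randn(2, 6, 768)
query_mask = torch.ones(2, 6)
key_mask = torch.tensor([[1., 1., 1., 1., 1., 1.],
                         [1., 1., 1., 1., 0., 0.]])  # last two key positions are padding

out = scaled_dot_product_attention(query, key, value, query_mask, key_mask)
print(out.size())  # torch.Size([2, 6, 768])
```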

Chaochao2020 · Oct 13 '23 11:10

  1. Because it is self-attention, Q = K = V.
  2. If K.size(-1) != query.size(-1), how could the matrices be multiplied in the first place? They must match, so the two expressions are equivalent (see the shape check below).
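
A quick shape check with made-up tensors illustrates point 2: torch.bmm(query, key.transpose(1, 2)) contracts over the last dimension, so it only works when query.size(-1) == key.size(-1), and either one can serve as dim_k.

```python
import torch

query = torch.randn(1, 5, 64)
key = torch.randn(1, 7, 64)  # same last dimension as query; sequence lengths may differ

# The dot product contracts over the last dimension of query and key
scores = torch.bmm(query, key.transpose(1, 2))
print(scores.size())                    # torch.Size([1, 5, 7])
print(query.size(-1) == key.size(-1))   # True, so dim_k is the same either way
```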

Melmaphother · Mar 25 '24 13:03

Thanks

Chaochao2020 · Mar 26 '24 05:03