annotated-transformer
Some questions about MultiHeadedAttention
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # this should be deleted?
        self.dropout = nn.Dropout(p=dropout)
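For context, the snippet above depends on a clones helper that is not shown here. In the Annotated Transformer it is roughly the following (a minimal sketch, assuming standard PyTorch):

import copy
import torch.nn as nn

def clones(module, N):
    "Produce N identical layers (independent deep copies) in a ModuleList."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])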
    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.mul, (query, key, value))  # self.mul should be self.linears?
        ]
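To see what the view/transpose in step 1 does to the tensor shapes, here is a small standalone sketch. The toy sizes (d_model=8, h=2, batch 3, sequence length 5) are assumptions for illustration only, not values from the code above:

import torch
import torch.nn as nn

d_model, h = 8, 2
d_k = d_model // h          # 4
nbatches, seq_len = 3, 5

lin_q = nn.Linear(d_model, d_model)        # stands in for one of self.linears
x = torch.randn(nbatches, seq_len, d_model)

# Project, split the last dimension into h heads of size d_k,
# then move the head dimension in front of the sequence dimension.
q = lin_q(x).view(nbatches, -1, h, d_k).transpose(1, 2)
print(q.shape)  # torch.Size([3, 2, 5, 4]) -> (nbatches, h, seq_len, d_k)

Each head then attends over the sequence with its own d_k-sized slice of the projected representation, which is exactly what the batched projection in step 1 sets up.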