annotated-transformer

Some questions about MultiHeadedAttention

Open · SteveBetter opened this issue 3 years ago · 0 comments

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # <-- this should be deleted?
        self.dropout = nn.Dropout(p=dropout)
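A note on the first question: in the upstream notebook, as I understand it, self.attn = None is only a placeholder; the same attribute is reassigned later in forward to cache the attention weights (handy for visualization), so the line is harmless rather than dead code. A minimal sketch of that later step, assuming the notebook's attention() helper:

    # Sketch (assumed upstream code, not part of the quote above): inside forward(),
    # after the linear projections, attention() returns the weighted values and the
    # softmax attention map, and the map is cached on self.attn.
    x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)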

def forward(self, query, key, value, mask=None):
    "Implements Figure 2"
    if mask is not None:
        # Same mask applied to all h heads.
        mask = mask.unsqueeze(1)
    nbatches = query.size(0)

    # 1) Do all the linear projections in batch from d_model => h x d_k
    query, key, value = [
        lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        for lin, x in zip(self.mul (query, key, value))  # <-- self.mul should be self.linears?
    ]
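On the second question: the published version of this step uses self.linears, so the quoted self.mul looks like a transcription slip rather than a different attribute. A hedged sketch of the projection step under that assumption:

    # Sketch (assuming self.linears is intended): the first three of the four cloned
    # Linear layers project query, key, and value; the fourth is applied to the
    # concatenated heads at the end of forward().
    query, key, value = [
        lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        for lin, x in zip(self.linears, (query, key, value))
    ]

After this reshape and transpose, each of query, key, and value has shape (nbatches, h, seq_len, d_k).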

SteveBetter · Jun 05 '22