PyTorch-Pretrained-ViT
Multi head uses just one set of Q, K, V?
In transformer.py, the class MultiHeadedSelfAttention contains the declarations:

```python
self.proj_q = nn.Linear(dim, dim)
self.proj_k = nn.Linear(dim, dim)
self.proj_v = nn.Linear(dim, dim)
```
but weren't Q, K, and V supposed to be independent trainable matrices per head? E.g., if num_heads = 12, shouldn't it be something like:

```python
heads = []  # renamed from `set`, which shadows the Python builtin
for i in range(12):
    heads.append([nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)])
```
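For context on the question: a common way to see why a single `nn.Linear(dim, dim)` can still give each head its own trainable weights is that the output is reshaped into `num_heads` slices of size `dim // num_heads`, so each head effectively owns a disjoint block of rows of the one big weight matrix. A minimal sketch (assuming `dim=768` and 12 heads, as in ViT-Base; the variable names here are illustrative, not taken from transformer.py):

```python
import torch
import torch.nn as nn

dim, num_heads = 768, 12
head_dim = dim // num_heads  # 64

proj_q = nn.Linear(dim, dim)     # one fused projection, as in transformer.py
x = torch.randn(2, 50, dim)      # (batch, tokens, dim)

# Reshape the single output so each head gets its own head_dim-sized slice:
q = proj_q(x).view(2, 50, num_heads, head_dim).transpose(1, 2)  # (batch, heads, tokens, head_dim)

# Head 0 is equivalent to an independent (dim -> head_dim) linear layer
# built from rows [0:head_dim] of the fused weight and bias:
w0 = proj_q.weight[:head_dim]
b0 = proj_q.bias[:head_dim]
q0 = x @ w0.t() + b0

assert torch.allclose(q[:, 0], q0, atol=1e-5)
```

So the 12 per-head projections are not shared; they are fused into one matmul for efficiency, which is mathematically the same as 12 independent `dim -> head_dim` matrices stacked together.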
Regards!