vit-pytorch
How to use Multi-Head Attention in ViT
I need help understanding the multi-head attention in ViT.
import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # skip the output projection when a single head already matches the model dimension
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, num_tokens, dim), i.e. a sequence of patch embeddings, not a raw image
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # scaled dot-product attention
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
I'm confused about dim and dim_head here. For example, if I want to input an image in PyTorch with batchsize, channels, width, height = x.size(), what should dim be? Is dim = channels?
Also, in def forward(self, x), what do b, n, h and d in 'b n (h d)' stand for?
dim (e.g., 1024) is the output dimension of the multi-head attention module: it is the embedding dimension of each token after patch embedding, not the number of image channels. Conventionally dim_head = dim // heads, so when heads = 1, dim_head equals dim; in this implementation dim_head is an independent argument, and inner_dim = dim_head * heads is projected back to dim whenever the two differ.
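To make that concrete, here is a minimal sketch (the 224x224 image, 16x16 patch size, and dim = 1024 are just assumed example values, following the standard ViT patch embedding) showing that Attention consumes a sequence of patch tokens of shape (batch, num_patches, dim), not the raw (batch, channels, height, width) image:

import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, channels, height, width = 2, 3, 224, 224     # assumed example image shape
patch_size, dim = 16, 1024                          # assumed example hyperparameters

img = torch.randn(batch, channels, height, width)

# patch embedding: cut the image into 16x16 patches and project each one to dim
to_tokens = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
    nn.Linear(channels * patch_size * patch_size, dim),
)

tokens = to_tokens(img)                              # (2, 196, 1024), 196 = (224/16) * (224/16) patches
attn = Attention(dim = dim, heads = 8, dim_head = 64)
out = attn(tokens)                                   # (2, 196, 1024), dim is preserved
print(tokens.shape, out.shape)

In the full ViT a class token is also prepended before attention, which is why n becomes num_patches + 1.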
b: batch size
n: num_patches + 1 (the extra 1 is the class token)
h: number of heads (heads)
d: dim_head
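As a quick check of those letters, here is a small shape trace (the concrete numbers are just assumed examples) through the two rearrange calls in forward:

import torch
from einops import rearrange

b, n, heads, dim_head = 2, 197, 8, 64        # assumed: batch of 2, 196 patches + 1 class token
t = torch.randn(b, n, heads * dim_head)      # one chunk of the to_qkv output: (b, n, h * d)

q = rearrange(t, 'b n (h d) -> b h n d', h = heads)
print(q.shape)                               # torch.Size([2, 8, 197, 64])

# attention scores are (b, h, n, n); multiplying by v gives (b, h, n, d) again
out = rearrange(q, 'b h n d -> b n (h d)')
print(out.shape)                             # torch.Size([2, 197, 512]), i.e. (b, n, inner_dim)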
I hope it works for you.