
A question about part of the code in attention


In the attention code I came across an operation called `to_out`, and I cannot work out what functionality it is meant to implement. The code in question is:

```python
import torch
from torch import nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # a single projection produces the shared tensor w, used as q, k and v
        self.qkv = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        w = rearrange(self.qkv(x), 'b n (h d) -> b h n d', h=self.heads)
        dots = torch.matmul(w, w.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, w)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```

After `forward` computes its result, the final output is passed through `to_out()`, but I could not find a corresponding theoretical justification in the MSSA part of the paper. Could you please explain it?
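For reference, my current guess from reading the shapes is that `to_out` is simply a learned output projection that maps the concatenated heads (of width `inner_dim = heads * dim_head`) back to the model width `dim`. A minimal sketch of what I mean, assuming the `Attention` class above is defined (the batch size, token count, and dimensions below are made up for illustration):

```python
import torch
from einops import rearrange

# hypothetical sizes: inner_dim = heads * dim_head = 512 != dim = 96,
# so project_out is True and to_out is a real Linear(512, 96) + Dropout
dim, heads, dim_head = 96, 8, 64
attn = Attention(dim, heads=heads, dim_head=dim_head)

x = torch.randn(2, 10, dim)                                   # (batch, tokens, dim)
w = rearrange(attn.qkv(x), 'b n (h d) -> b h n d', h=heads)   # (2, 8, 10, 64)
heads_out = rearrange(w, 'b h n d -> b n (h d)')              # concatenated heads

print(heads_out.shape)               # torch.Size([2, 10, 512])
print(attn.to_out(heads_out).shape)  # torch.Size([2, 10, 96]): back to width dim
```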

Also, where in the code are the LayerNorm before the MSSA module and the LayerNorm before ISTA implemented? I could not find the corresponding code.
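In case it helps pin down what I mean: my best guess is that those LayerNorms live in a pre-norm wrapper that is applied around each block when the layers are assembled, something like the `PreNorm` pattern from vit-pytorch sketched below (the class name and the wiring are my assumption, not code I actually found in this repo):

```python
from torch import nn

class PreNorm(nn.Module):
    """Apply LayerNorm to the input before handing it to the wrapped module."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# hypothetical wiring: LayerNorm would then run right before MSSA (Attention)
# and right before ISTA (the feed-forward block) without appearing inside them:
#   layers.append(nn.ModuleList([
#       PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)),
#       PreNorm(dim, FeedForward(dim, mlp_dim)),
#   ]))
```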

01vanilla · Dec 09 '23 11:12