Torch-KWT
Wrong residual structure when using PostNorm
In file Torch-KWT/model/kwt.py, the PostNorm class is written as:

```python
class PostNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm(self.fn(x, **kwargs))
```
And the residual structure for both PostNorm and PreNorm is written as:
```python
for attn, ff in self.layers:
    x = attn(x) + x
    x = ff(x) + x
```
I think this residual structure is only correct for PreNorm. For PostNorm, it differs from the original Transformer paper: your code adds the original x after the sublayer output has been normalized, whereas the residual should be added before normalization, i.e. x = norm(sublayer(x) + x) rather than norm(sublayer(x)) + x.
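A minimal sketch of one possible fix, assuming the residual addition is moved inside PostNorm (so the outer loop would then call the wrapped sublayers without its own `+ x`); this is my suggestion, not code from the repository:

```python
import torch
import torch.nn as nn


class PostNorm(nn.Module):
    """Post-norm block as in the original Transformer:
    add the residual first, then apply LayerNorm."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Residual is added *before* normalization.
        return self.norm(self.fn(x, **kwargs) + x)


if __name__ == "__main__":
    torch.manual_seed(0)
    dim = 8
    block = PostNorm(dim, nn.Linear(dim, dim))  # nn.Linear stands in for attn/ff
    x = torch.randn(2, 4, dim)
    out = block(x)
    print(out.shape)  # same shape as the input
```

With this definition, the transformer loop would become `x = attn(x)` followed by `x = ff(x)`, since the residual connection is now handled inside the wrapper; keeping the external `+ x` as well would add the residual twice.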