
Wrong residual structure when using PostNorm

Open jackykj opened this issue 1 year ago • 0 comments

In file Torch-KWT/model/kwt.py, the PostNorm class is written as:

    class PostNorm(nn.Module):
        def __init__(self, dim, fn):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fn = fn

        def forward(self, x, **kwargs):
            return self.norm(self.fn(x, **kwargs))

And the residual structure for both PostNorm and PreNorm is written as:

    for attn, ff in self.layers:
        x = attn(x) + x
        x = ff(x) + x

I think this is only correct when using PreNorm. For PostNorm, the structure differs from the original Transformer paper: the code adds the residual x after normalizing the sublayer output, i.e. `norm(fn(x)) + x`, whereas PostNorm should add the residual before normalizing, i.e. `norm(fn(x) + x)`.
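One possible fix is a minimal sketch, assuming the residual is moved inside PostNorm itself (so the outer loop would no longer add x for PostNorm blocks). The class name and constructor signature follow the repo's kwt.py; the rest is illustrative, not the maintainer's actual patch:

```python
import torch
import torch.nn as nn

class PostNorm(nn.Module):
    """Post-normalization as in the original Transformer:
    the residual is added *before* the LayerNorm."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # residual first, then normalize: norm(fn(x) + x)
        return self.norm(self.fn(x, **kwargs) + x)

# With this variant, the transformer loop would drop its own skip
# connection for PostNorm blocks, since the class already adds x:
#     for attn, ff in self.layers:
#         x = attn(x)
#         x = ff(x)
```

An alternative with the same effect would be to keep PostNorm as `norm(fn(x))` and change the loop to `x = self.norm(attn(x) + x)` per sublayer; either way, the residual must be summed before the LayerNorm, not after.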

jackykj · Jul 13 '23 07:07