FNet-pytorch
Layer Normalization
Hi, thanks for a great implementation!
I wanted to clarify one thing that does not match the code proposed in the article itself. In your code, you pre-normalize the inputs: they are passed through LayerNorm before the FFT. The code presented in the article, however, has:
class FNetEncoderBlock(nn.Module):
  fourier_layer: FourierTransformLayer
  ff_layer: FeedForwardLayer

  @nn.compact
  def __call__(self, x, deterministic):
    mixing_output = self.fourier_layer(x)
    x = nn.LayerNorm(1e-12, name="mixing_layer_norm")(x + mixing_output)
    feed_forward_output = self.ff_layer(x, deterministic)
    return nn.LayerNorm(
        1e-12, name="output_layer_norm")(x + feed_forward_output)
which, in my view, applies the normalization in the opposite order (after the mixing and feed-forward sublayers rather than before). Am I mistaken, or is it indeed a bug?
(The snippet above is listing A.5 from the paper.)
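To make the difference concrete, here is a minimal PyTorch sketch of the two orderings for the Fourier sublayer. The module and variable names are my own, not code from either the repo or the paper:

import torch
from torch import nn

class PostNormMixing(nn.Module):
    """Ordering in the paper's listing: mix, then add the residual and normalize."""
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-12)

    def forward(self, x):
        # FNet mixing: real part of the FFT over hidden and sequence dims
        mixing_output = torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
        return self.norm(x + mixing_output)  # LayerNorm after the FFT and residual

class PreNormMixing(nn.Module):
    """Ordering described in the question: normalize first, then mix."""
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-12)

    def forward(self, x):
        # LayerNorm before the FFT
        mixing_output = torch.fft.fft(torch.fft.fft(self.norm(x), dim=-1), dim=-2).real
        return x + mixing_output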
A similar question concerns dropout in the FeedForward layer: you add it twice, whereas in the paper it is applied only once, at the end.

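For comparison, here is a rough PyTorch sketch of the feed-forward sublayer with dropout applied only at the end, matching the placement the paper's listing describes. This is my own sketch, not the paper's or the repo's code; d_model, d_ff, and dropout_rate are placeholder parameters:

from torch import nn

class FeedForwardSketch(nn.Module):
    """Dense -> GELU -> Dense, with a single dropout on the output."""
    def __init__(self, d_model, d_ff, dropout_rate=0.1):
        super().__init__()
        self.dense_in = nn.Linear(d_model, d_ff)
        self.act = nn.GELU()
        self.dense_out = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_rate)  # the only dropout in this sublayer

    def forward(self, x):
        x = self.act(self.dense_in(x))  # no dropout after the activation
        x = self.dense_out(x)
        return self.dropout(x)          # dropout applied once, at the end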
@Aktsvigun you can check out our repo https://github.com/erksch/fnet-pytorch. We reimplemented the architecture precisely, to such a degree that we can even use the official checkpoints (converted from JAX to PyTorch).