[Bug Report] Layer norm folding not properly implemented for BertBlock
Describe the bug
LayerNorm folding is currently only implemented for the default TransformerBlock. If you use BertBlock (where the LayerNorm is applied after the attention and MLP layers rather than before), the LayerNorm still gets folded into W_Q, W_K, and W_V, which is incorrect.
I believe the LayerNorm should instead be folded into W_O for attention and fc2 for the MLP, but please check this.
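For reference, here is a minimal sketch of why the fold is valid in the pre-LN case (toy shapes and illustrative tensor names, not the actual TransformerLens attributes): the LayerNorm's learned scale and bias sit directly in front of W_Q/W_K/W_V, so they can be absorbed into those matrices and their biases.

```python
import torch

# Toy dimensions; names are illustrative, not TransformerLens internals.
d_model, d_head = 8, 4

gamma = torch.randn(d_model)        # ln1 scale
beta = torch.randn(d_model)         # ln1 bias
W_Q = torch.randn(d_model, d_head)
b_Q = torch.randn(d_head)

# Pre-LN: queries are (gamma * x_hat + beta) @ W_Q + b_Q, where x_hat is the
# centred-and-scaled residual stream, so gamma/beta can be absorbed:
W_Q_folded = gamma[:, None] * W_Q   # scale the rows of W_Q by gamma
b_Q_folded = beta @ W_Q + b_Q       # push beta through W_Q into the bias

x_hat = torch.randn(d_model)        # any already-normalised input
assert torch.allclose((gamma * x_hat + beta) @ W_Q + b_Q,
                      x_hat @ W_Q_folded + b_Q_folded, atol=1e-5)
```

Whether an analogous fold exists for the post-LN ordering in BertBlock is exactly the open question here.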
Additional context
I discovered the bug when adapting the repo to CLIP. Many implementations of CLIP use BertBlock, not TransformerBlock.
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
This may be of interest to @jbloomAus and @rusheb
Ah, no, LayerNorm should not be folded at all here. You cannot fold it into W_O, because that would change the norm of the output of the layer and thus the LayerNorm scale. I can't think of any way to do LayerNorm folding for BERT-style blocks, unfortunately.
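To make that concrete, here is a toy check (toy shapes, scale-only LayerNorm, illustrative names, not TransformerLens internals) showing that scaling W_O is not equivalent to applying the post-attention LayerNorm scale: the residual term is untouched by W_O, and the rescaling changes the norm that the LayerNorm divides by.

```python
import torch

def ln(z):
    # Parameter-free LayerNorm: centre and scale, no learned affine.
    return (z - z.mean(-1, keepdim=True)) / z.std(-1, keepdim=True, unbiased=False)

d_model = 8
x = torch.randn(d_model)             # residual stream entering the block
a = torch.randn(d_model)             # attention output before W_O (toy width)
W_O = torch.randn(d_model, d_model)
gamma = torch.rand(d_model) + 0.5    # post-attention LayerNorm scale

# BertBlock applies the LayerNorm scale *after* normalising x + a @ W_O
out_true = gamma * ln(x + a @ W_O)

# Attempted fold: absorb gamma into W_O by scaling its columns
out_folded = ln(x + a @ (W_O * gamma))

print(torch.allclose(out_true, out_folded))  # False: the fold changes the norm
# of the LayerNorm's input (and never touches the residual term x), so the
# division inside ln() changes and the two computations disagree.
```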