Easy-Transformer
Fix to include ln_final.w in RMSNorm hook
Description
Before this fix, model_cache["ln_final.hook_normalized"] only returns the RMS-normalized hidden states, without multiplying by the final_ln weight. This may contradict the intended design of this hook. I followed the implementation of the LayerNorm hook and included self.w in the hook_normalized hook point.
I have found no issues related to this so far. Hope this helps :)
Before Fix
def forward(
    self, x: Float[torch.Tensor, "batch pos length"]
) -> Float[torch.Tensor, "batch pos length"]:
    if self.cfg.dtype not in [torch.float32, torch.float64]:
        x = x.to(torch.float32)
    scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
        (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
    )
    x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
    return x * self.w
After Fix
def forward(
    self, x: Float[torch.Tensor, "batch pos length"]
) -> Float[torch.Tensor, "batch pos length"]:
    if self.cfg.dtype not in [torch.float32, torch.float64]:
        x = x.to(torch.float32)
    scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
        (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
    )
    x = x / scale  # [batch, pos, length]
    return self.hook_normalized(x * self.w).to(self.cfg.dtype)
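For context, a minimal sketch (not part of the PR) of how the difference would surface when caching activations. The model name is a placeholder, and it assumes an RMSNorm-based checkpoint loaded with fold_ln=False so that ln_final keeps its own weight:

import torch
from transformer_lens import HookedTransformer

# Placeholder checkpoint name; any RMSNorm-based model loaded without ln folding applies.
model = HookedTransformer.from_pretrained("<rmsnorm-model>", fold_ln=False)
tokens = model.to_tokens("Hello world")
logits, cache = model.run_with_cache(tokens)

normalized = cache["ln_final.hook_normalized"]
# Before the fix: normalized == x / scale
# After the fix:  normalized == (x / scale) * model.ln_final.w
print(normalized.shape)  # [batch, pos, d_model]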
Type of change
- [x] Bug fix (non-breaking change which fixes an issue)
Checklist:
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
- [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility
I'm open to pushback, but I currently think that hook_normalized is working as intended, because it's invariant to whether layer norm is folded.
Yes, I believe hook_normalized itself works fine. But I think the problem is when retrieving the final normalized output after RMSNorm: say we call llama_logits, llama_cache = model.run_with_cache(llama_tokens) and then read llama_cache["ln_final.hook_normalized"]. This returns x / scale instead of x / scale * self.w, since self.w is not included in hook_normalized.
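For illustration only, a sketch of how the weighted output can be recovered manually given the pre-fix behaviour. It reuses the names from the comment above (model and llama_tokens are assumed to be defined) and assumes the model was loaded with fold_ln=False, so ln_final still owns its weight:

llama_logits, llama_cache = model.run_with_cache(llama_tokens)

pre_weight = llama_cache["ln_final.hook_normalized"]   # x / scale (pre-fix behaviour)
post_weight = pre_weight * model.ln_final.w            # x / scale * self.w, applied manually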
@clarenceluo78 I think the key point here is that a common thing done in TransformerLens is folding the layer norm weights into the next linear layer. See https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln for details.
Therefore excluding it from hook_normalized seems to make sense as we want to be able to extract something that is invariant to whether or not fold_ln is enabled.
Let me know if you disagree however.
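To illustrate the invariance argument, a hedged sketch (the model name is a placeholder, and agreement is only expected up to floating-point error): with self.w excluded from hook_normalized, the cached value should match whether or not layer norm folding is enabled, whereas including the weight would make the two runs differ.

import torch
from transformer_lens import HookedTransformer

# Placeholder checkpoint name; load the same weights with and without ln folding.
folded = HookedTransformer.from_pretrained("<model-name>", fold_ln=True)
unfolded = HookedTransformer.from_pretrained("<model-name>", fold_ln=False)

tokens = folded.to_tokens("The quick brown fox")
_, cache_folded = folded.run_with_cache(tokens)
_, cache_unfolded = unfolded.run_with_cache(tokens)

# With self.w excluded, hook_normalized captures the same activation in both runs.
print(torch.allclose(cache_folded["ln_final.hook_normalized"],
                     cache_unfolded["ln_final.hook_normalized"], atol=1e-4))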
I think that this should not be merged, but we should reopen and merge the PR that fixes the inconsistency that @clarenceluo78 noticed. Unless @neelnanda-io can explain why the behavior of the hook_normalized caching in LayerNorm is correct?
I am going to go ahead and close this one. If anything further needs to be done with this, let's discuss it on slack.