
Fix to include ln_final.w in RMSNorm hook

Open clarenceluo78 opened this issue 1 year ago • 4 comments

Description

Before the fix, model_cache["ln_final.hook_normalized"] returns only the RMS-normalized hidden states, without multiplying by the final_ln weight. This might contradict the intended design of this hook. Following the implementation of the LayerNorm hook, I included self.w in the hook_normalized hook point.

So far I have not found any issues related to this. Hope this helps :)

Before Fix

def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
        return x * self.w

After Fix

def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = x / scale  # [batch, pos, length]
        return self.hook_normalized(x * self.w).to(self.cfg.dtype)
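
For what it's worth, the difference is easy to see by running the RMSNorm module on its own with a hook attached and comparing the cached tensor against the module's output. This is only a minimal sketch against the current (pre-fix) behaviour; the toy config values and the store helper are arbitrary choices for illustration, and it assumes RMSNorm can be imported from transformer_lens.components.

import torch
from transformer_lens import HookedTransformerConfig
from transformer_lens.components import RMSNorm

# Arbitrary toy config; only d_model, eps and dtype matter for RMSNorm.
cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu")
ln = RMSNorm(cfg)
ln.w.data = torch.arange(1.0, 9.0)  # non-trivial weight so the difference is visible

cached = {}
def store(tensor, hook):  # hypothetical helper that records the hooked activation
    cached["normalized"] = tensor
ln.hook_normalized.add_hook(store)

x = torch.randn(1, 3, cfg.d_model)
out = ln(x)
scale = (x.pow(2).mean(-1, keepdim=True) + cfg.eps).sqrt()

# Current behaviour: the cached activation is x / scale, i.e. ln.w is not applied yet.
print(torch.allclose(cached["normalized"], x / scale, atol=1e-5))  # True
# The module output does include the weight; the proposed change would cache this instead.
print(torch.allclose(out, (x / scale) * ln.w, atol=1e-5))          # True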

Type of change

  • [x] Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility

clarenceluo78 · Aug 26 '23 16:08

I'm open to pushback, but I currently think that hook_normalized is working as intended, because it is invariant to whether layer norm is folded.


neelnanda-io · Aug 26 '23 18:08

Yes, I believe hook_normalized itself works fine. But the problem shows up when retrieving the final normalized output after RMSNorm: if you call llama_logits, llama_cache = model.run_with_cache(llama_tokens) and then read llama_cache["ln_final.hook_normalized"], it returns x / scale rather than x / scale * self.w, since self.w is not included in hook_normalized.
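
(As an aside, under the current behaviour the weighted output can still be recovered by hand, assuming the model was loaded with fold_ln=False so that ln_final.w is a real parameter; a two-line sketch reusing the names above, not part of the PR:)

normalized = llama_cache["ln_final.hook_normalized"]  # currently x / scale
weighted = normalized * model.ln_final.w              # manually reproduces (x / scale) * self.w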

clarenceluo78 · Aug 29 '23 15:08

@clarenceluo78 I think the key point here is that a common thing done in TransformerLens is folding the layer norm weights into the next linear layer. See https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln for details.

Therefore excluding it from hook_normalized seems to make sense as we want to be able to extract something that is invariant to whether or not fold_ln is enabled.
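
To make the folding argument concrete, here is a rough numerical sketch of the idea (not the library's actual fold_ln code; all names below are made up): absorbing the norm weight into the next linear map gives the same logits whether or not the weight is applied inside the norm, so the weight-free activation is the folding-invariant one.

import torch

d_model, d_vocab = 8, 50
w_ln = torch.rand(d_model) + 0.5        # stand-in for ln_final.w
W_U = torch.randn(d_model, d_vocab)     # stand-in unembedding matrix
W_U_folded = w_ln[:, None] * W_U        # fold the norm weight into the unembedding

x = torch.randn(2, 3, d_model)
scale = (x.pow(2).mean(-1, keepdim=True) + 1e-5).sqrt()
normalized = x / scale                  # what hook_normalized caches today

# Identical logits either way, so excluding w from hook_normalized keeps it fold-invariant.
print(torch.allclose((normalized * w_ln) @ W_U, normalized @ W_U_folded, atol=1e-4))  # True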

Let me know if you disagree however.

alan-cooney · Oct 13 '23 09:10

I think that this should not be merged, but we should reopen and merge the PR that fixes the inconsistency that @clarenceluo78 noticed. Unless @neelnanda-io can explain why the behavior of the hook_normalized caching in LayerNorm is correct?

ArthurConmy · Oct 13 '23 23:10

I am going to go ahead and close this one. If anything further needs to be done with this, let's discuss it on slack.

bryce13950 · May 03 '24 23:05