
[Proposal] [Low-Priority] Improve LN1's hooks

[Open] ArthurConmy opened this issue 1 year ago · 0 comments

Background: When I implemented use_split_qkv_input here, I changed the attention module to take three inputs (one each for the query, key and value). Even when this feature is not enabled, we compute the layer-normalized query, key and value inputs separately, which can cause problems in some cases (see below).

https://github.com/neelnanda-io/TransformerLens/blob/c6f417d/transformer_lens/components.py#L920-L924
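
Roughly, the linked lines apply ln1 once per input. Here is a toy paraphrase in plain PyTorch (illustrative names and shapes, not TransformerLens's actual code):

```python
import torch
import torch.nn as nn

batch, pos, d_model = 2, 5, 8
ln1 = nn.LayerNorm(d_model)

resid_pre = torch.randn(batch, pos, d_model)
query_input = key_input = value_input = resid_pre  # identical tensors when use_split_qkv_input is off

# The block normalizes each input with a separate ln1 call, even when
# all three are the same residual-stream tensor.
q_norm = ln1(query_input)  # first ln1 call (and first hook firing)
k_norm = ln1(key_input)    # second call
v_norm = ln1(value_input)  # third call
```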

New behavior: This change means that hooks on ln1 get called three times per forward pass. This has downsides: for example, under run_with_cache the cached LayerNorm activation is overwritten three times, so only the last call's output survives in the cache.
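
As a sketch of how this shows up (assuming a standard TransformerLens setup; the model name and layer-0 hook are just illustrative choices):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# Count how often ln1's output hook fires on one forward pass.
calls = []
model.run_with_hooks(
    "Hello world",
    fwd_hooks=[(
        "blocks.0.ln1.hook_normalized",
        lambda tensor, hook: calls.append(tensor.shape),
    )],
)
print(len(calls))  # expected 3 with the linked code: query, key, value

# run_with_cache stores one tensor per hook name, so the first two
# ln1 activations are overwritten and only the last call survives.
_, cache = model.run_with_cache("Hello world")
print(cache["blocks.0.ln1.hook_normalized"].shape)
```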

How big a deal is this? For most applications I don't think this matters, so plausibly people should work on other issues rather than this one. But I thought I'd record it.

What we could change: adding an extra dimension of length 3 to LN1's input (stacking the query, key and value inputs) and calling LN1 once, instead of calling it three times, could fix this; a sketch follows.
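
A minimal sketch of that idea in plain PyTorch (names and shapes are hypothetical, not a patch against the repo): stack the three inputs along a new leading dimension of length 3 and call LayerNorm once, so ln1's hooks fire a single time on a single tensor.

```python
import torch
import torch.nn as nn

batch, pos, d_model = 2, 5, 8
ln1 = nn.LayerNorm(d_model)

query_input = torch.randn(batch, pos, d_model)
key_input = query_input.clone()
value_input = query_input.clone()

# Shape [3, batch, pos, d_model]; LayerNorm normalizes over the last
# (d_model) dimension, so the extra leading dimension passes through
# unchanged.
stacked = torch.stack([query_input, key_input, value_input], dim=0)
normalized = ln1(stacked)  # one call, so one hook firing
q_norm, k_norm, v_norm = normalized  # unbind along the length-3 dim

# Matches the result of three separate ln1 calls.
assert torch.allclose(q_norm, ln1(query_input))
```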

- [x] I have checked that there is no similar issue in the repo (required)

ArthurConmy · Jul 04 '23