Easy-Transformer
[Proposal] [Low-Priority] Improve LN1's hooks
Background: When I implemented use_split_qkv_input here, I changed the attention module to take three separate inputs (query, key and value). Even when this feature is not enabled, we compute the layer-normalized query, key and value inputs separately, which may not work properly in some cases.
https://github.com/neelnanda-io/TransformerLens/blob/c6f417d/transformer_lens/components.py#L920-L924
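A minimal, illustrative sketch of that behavior (not the library's exact code; `torch.nn.LayerNorm` stands in for the block's LN1, and the variable names and shapes are assumptions): applying ln1 separately to the query, key and value inputs means every hook attached to it fires three times per layer.

```python
import torch

d_model = 8
ln1 = torch.nn.LayerNorm(d_model)  # stand-in for the block's LN1 and its hooks

resid_pre = torch.randn(2, 5, d_model)              # [batch, pos, d_model]
query_input = key_input = value_input = resid_pre   # identical when use_split_qkv_input is off

query_input = ln1(query_input)   # 1st ln1 call -> ln1 hooks fire
key_input = ln1(key_input)       # 2nd ln1 call -> ln1 hooks fire again
value_input = ln1(value_input)   # 3rd ln1 call -> ln1 hooks fire a third time
```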
New behavior: This change means that hooks on ln1 get called three times per forward pass. This has downsides: for example, when we run_with_cache, the cache entries for the LayerNorm's hooks are overwritten three times, so only the values from the last call survive.
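A hedged sketch of how this shows up in practice, assuming a standard TransformerLens install, the "gpt2" checkpoint, and the usual "blocks.<layer>.ln1.hook_normalized" hook name:

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

calls = 0
def count_calls(tensor, hook):
    # Count how often this hook fires; returning None leaves the activation unchanged.
    global calls
    calls += 1

model.run_with_hooks(
    "Hello world",
    fwd_hooks=[("blocks.0.ln1.hook_normalized", count_calls)],
)
print(calls)  # 3 if ln1 is applied separately to q, k and v, as in the linked code

_, cache = model.run_with_cache("Hello world")
# The cache holds one tensor per hook name, even though the hook ran three
# times; the earlier values were overwritten by the later calls.
print(cache["blocks.0.ln1.hook_normalized"].shape)
```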
How big a deal is this? For most applications I don't think this matters, so plausibly people should work on another issue rather than this one. But I thought I'd record it.
What might we change: Adding an extra dimension of length 3 to LN1's input and calling it once, instead of calling it three times, could fix this; a sketch is below.
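A minimal sketch of that idea, assuming a LayerNorm that normalizes over the last (d_model) dimension; the names and shapes are illustrative, not the library's API:

```python
import torch

d_model = 8
ln1 = torch.nn.LayerNorm(d_model)  # stand-in for the block's LN1

resid_pre = torch.randn(2, 5, d_model)           # [batch, pos, d_model]
qkv_input = torch.stack([resid_pre] * 3, dim=0)  # [3 (q/k/v), batch, pos, d_model]

normalized = ln1(qkv_input)                      # one ln1 call -> hooks fire once
query_input, key_input, value_input = normalized # unbind along the new length-3 dim
```

A hook on ln1 would then see a single tensor with the extra length-3 leading dimension, so run_with_cache would store it once instead of overwriting it.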
- [x] I have checked that there is no similar issue in the repo (required)