
[Question] Creating codebooks for Q, K, V vectors

Open · cmathw opened this issue 1 year ago · 1 comment

In codebook_features/models.py, I can see a method for attaching codebooks to each attention block's query, key and value vectors:

https://github.com/taufeeque9/codebook-features/blob/a37ea8fe7d4d39298aaea042a078d09401396edc/codebook_features/models.py#L1439C1-L1460C36

After training a model with these codebooks attached, however, it does not seem possible to convert the model to a HookedTransformer model: doing so raises AttributeError: 'HookedTransformer' object has no attribute 'qkv_key'. What is the current status of using QKV codebooks together with conversion to a HookedTransformer model? I'm happy to write a PR if this needs to be integrated with the HookedTransformerCodebookModel class to work.

cmathw · Jan 30 '24 14:01

Currently, QKV codebooks only work with models that store the query, key and value weights in a single combined parameter, as GPT-2 and Pythia models do. I have created the branch tf/qkv to also add support for models that store them as separate parameters.
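As a rough illustration of the two storage layouts, the sketch below uses plain nn.Linear layers. This is an assumption for simplicity: GPT-2 actually uses a Conv1D for its fused c_attn projection, and Pythia orders its fused query_key_value output per head rather than as three contiguous blocks. The point it shows is that a fused projection whose output is split into thirds computes the same Q, K and V as three separate projections, which is presumably why one wrapped module is enough in the fused case, while separately stored projections need a codebook on each module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 5, d_model)  # [batch, pos, d_model]

# Separate storage: three independent projections, so a codebook would have to be
# attached to each module individually.
q_proj = nn.Linear(d_model, d_model)
k_proj = nn.Linear(d_model, d_model)
v_proj = nn.Linear(d_model, d_model)

# Fused storage: one projection emits Q, K and V side by side, so a single wrapper
# around it sees all three. Its weights are copied from the separate layers here.
qkv_proj = nn.Linear(d_model, 3 * d_model)
with torch.no_grad():
    # nn.Linear weights are [out_features, in_features], so stack along dim 0.
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    qkv_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

# Splitting the fused output into thirds recovers Q, K and V (GPT-2-style layout).
q, k, v = qkv_proj(x).split(d_model, dim=-1)
print(torch.allclose(q, q_proj(x), atol=1e-6),
      torch.allclose(k, k_proj(x), atol=1e-6),
      torch.allclose(v, v_proj(x), atol=1e-6))
```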

With that branch, you should be able to load the QKV codebooks into HookedTransformerCodebookModel. However, transformer_lens stores the Q, K and V matrices as nn.Parameters (W_Q, W_K and W_V) and computes the projections with einsum operations directly on them, so the forward function of HookedTransformer throws an error once the codebook wrapper is applied to the QKV parameters.
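To illustrate why this parameter-plus-einsum layout is awkward to wrap, here is a toy query projection written in that style. It is a hypothetical stand-in, not TransformerLens's actual attention class; only the [n_heads, d_model, d_head] weight shape mirrors the convention used there. Because Q is produced by torch.einsum directly from W_Q, there is no child module whose output a codebook wrapper could intercept, and PyTorch refuses to put a module where a registered parameter lives.

```python
import torch
import torch.nn as nn


class EinsumQProjection(nn.Module):
    """Hypothetical toy query projection that stores W_Q as a raw nn.Parameter."""

    def __init__(self, d_model: int, n_heads: int, d_head: int):
        super().__init__()
        self.W_Q = nn.Parameter(torch.randn(n_heads, d_model, d_head) * 0.02)
        self.b_Q = nn.Parameter(torch.zeros(n_heads, d_head))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, pos, d_model] -> q: [batch, pos, n_heads, d_head]
        return torch.einsum("bpd,hde->bphe", x, self.W_Q) + self.b_Q


proj = EinsumQProjection(d_model=16, n_heads=4, d_head=4)
q = proj(torch.randn(2, 5, 16))

# Q is computed inside forward by einsum rather than by a child module, so there is
# no submodule output for a codebook wrapper to intercept.
print(list(proj.children()))  # []

# nn.Module.__setattr__ also rejects replacing a registered parameter with a module.
try:
    proj.W_Q = nn.Linear(16, 16)
    swap_failed = False
except TypeError as err:
    swap_failed = True
    print(err)
```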

To make QKV codebooks work with HookedTransformerCodebookModel, we'll need to create a custom class that overrides the forward function of HookedTransformer's attention block so that it works with or without codebooks. Essentially, this means converting W_Q, W_K and W_V from nn.Parameters into nn.Linear modules, which would allow us to apply the wrappers on top of the QKV modules.
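A minimal sketch of that idea follows, under several assumptions. LinearQKVAttention, CodebookWrapper and ToyCodebook are hypothetical names, not classes from codebook_features or transformer_lens; the attention is single-head with no causal mask; and ToyCodebook simply replaces each vector with the sum of its top-k most cosine-similar codes, a simplification that may differ from the repository's codebook implementation. What it shows is that once Q, K and V come from nn.Linear submodules, attaching a codebook is just swapping a submodule for a wrapper, and the same forward function runs with or without it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCodebook(nn.Module):
    """Hypothetical codebook: snaps each vector to the sum of its top-k closest codes."""

    def __init__(self, num_codes: int, dim: int, k: int = 1):
        super().__init__()
        self.codes = nn.Parameter(torch.randn(num_codes, dim))
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cosine similarity between every input vector and every code: [..., num_codes]
        sims = F.normalize(x, dim=-1) @ F.normalize(self.codes, dim=-1).T
        idx = sims.topk(self.k, dim=-1).indices  # [..., k]
        return self.codes[idx].sum(dim=-2)  # [..., dim]


class CodebookWrapper(nn.Module):
    """Hypothetical wrapper: runs the wrapped module, then passes its output through a codebook."""

    def __init__(self, module: nn.Module, codebook: nn.Module):
        super().__init__()
        self.module = module
        self.codebook = codebook

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.codebook(self.module(x))


class LinearQKVAttention(nn.Module):
    """Toy attention whose Q/K/V projections are nn.Linear modules, so each can be wrapped."""

    def __init__(self, d_model: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q(x), self.k(x), self.v(x)
        scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5  # [batch, pos, pos]
        return scores.softmax(dim=-1) @ v  # [batch, pos, d_model]


torch.manual_seed(0)
attn = LinearQKVAttention(d_model=16)
x = torch.randn(2, 5, 16)
out_plain = attn(x)

# Attaching a codebook to the query projection is an ordinary submodule swap.
attn.q = CodebookWrapper(attn.q, ToyCodebook(num_codes=32, dim=16, k=2))
out_codebook = attn(x)
print(out_plain.shape, out_codebook.shape)
```

Because the wrapper is an ordinary submodule, removing the codebook again only requires assigning the original nn.Linear back to attn.q, and k and v can be wrapped the same way.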

If you can implement the above (or have a better, less convoluted implementation), I'd be happy to review the PR!

taufeeque9 · Feb 06 '24 02:02