Easy-Transformer
HookedSAETransformer
Description
Implements `HookedSAETransformer`, an extension of `HookedTransformer` that lets SAEs be attached to activations. I implemented this for MATS research and found it useful, so I'm excited to open-source it. Feedback is welcome!
Here's a demo notebook: https://colab.research.google.com/github/ckkissane/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb
Type of change
- [x] New feature (non-breaking change which adds functionality)
- [x] This change requires a documentation update
Checklist:
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
- [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility
Just read through the notebook.
For what it's worth, I strongly dislike the design choice of having persistent state tracking which SAEs are turned on or off in forward passes. I think this feature is likely to lead to confusion and bugs, where users think certain SAEs are being applied when they are not. I think the interface `model.run_with_saes(input, act_names = saes_to_use)` is much more explicit and easier to use. It's also closer to the behavior of TransformerLens's `run_with_hooks` (although not identical -- maybe it's worth thinking about how to unify the APIs more cleanly).
Of course, it's possible to use the current code purely in the `model.run_with_saes(input, act_names = saes_to_use)` style and ignore the global on/off state. But this is still error-prone, as `model.run_with_saes` has unintuitive side effects. For instance, consider the following code:
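```python
model = ...  # a HookedSAETransformer with all SAEs turned on
outputA = model(input)  # all SAEs applied
outputB = model.run_with_saes(input, act_names=[sae1, sae2])  # only sae1 and sae2 applied
outputC = model(input)  # under the current proposal: no SAEs applied
```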
I'd strongly expect outputA and outputC to be the same! But as I understand the current proposal, outputA is the model with all SAEs and outputC is the model without any SAEs.
I would find this deeply confusing as a user if I hadn't read the documentation carefully. Even knowing how it works I'd find it annoying to work around.
Making it so `model.run_with_saes` restores the SAE states to what they were before the call would be a big improvement. But I'd much prefer making SAEs non-stateful, with three different ways to run the model: one with no SAEs, one with all SAEs, and one with a specifically chosen subset of SAEs.
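To make the restore-on-exit fix concrete, here is a minimal sketch of a `run_with_saes` that snapshots the on/off state and restores it in a `finally` block. It uses a hypothetical toy class (`ToySwitchableSAEModel`, scalar activations, plain functions standing in for SAEs), not the PR's actual `HookedSAETransformer` code; only the state-restoring pattern is meant to carry over.

```python
from typing import Callable, Dict, List

SAE = Callable[[float], float]  # toy stand-in: an SAE maps an activation to its reconstruction


class ToySwitchableSAEModel:
    """Hypothetical toy model with per-SAE on/off flags (not the real HookedSAETransformer)."""

    HOOKS = ("blocks.0.hook_resid_pre", "blocks.1.hook_resid_pre")

    def __init__(self) -> None:
        self.saes: Dict[str, SAE] = {}       # hook name -> attached SAE
        self.enabled: Dict[str, bool] = {}   # hook name -> is that SAE currently applied?

    def attach_sae(self, hook: str, sae: SAE) -> None:
        self.saes[hook] = sae
        self.enabled[hook] = True

    def forward(self, x: float) -> float:
        # Two toy "layers": an enabled SAE replaces the activation at its hook.
        for hook in self.HOOKS:
            if self.enabled.get(hook, False):
                x = self.saes[hook](x)
            x = x + 1.0
        return x

    def run_with_saes(self, x: float, act_names: List[str]) -> float:
        saved = dict(self.enabled)  # snapshot the on/off state before touching it
        try:
            for hook in self.enabled:
                self.enabled[hook] = hook in act_names
            return self.forward(x)
        finally:
            self.enabled = saved  # restore the previous state, even if forward raises


model = ToySwitchableSAEModel()
model.attach_sae("blocks.0.hook_resid_pre", lambda a: 0.5 * a)
model.attach_sae("blocks.1.hook_resid_pre", lambda a: 0.5 * a)

output_a = model.forward(2.0)                                                 # both SAEs on
output_b = model.run_with_saes(2.0, act_names=["blocks.0.hook_resid_pre"])  # only the first SAE
output_c = model.forward(2.0)                                                 # both on again
print(output_a, output_b, output_c)
```

With the snapshot and restore in place, `output_a` and `output_c` agree, which is the behavior the comment above expects.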
To be clear, I'm not a decision-maker for TransformerLens. You all should feel welcome to disregard my opinion.
> Making it so `model.run_with_saes` resets the SAE states to the same state as before the function was called would be a big improvement
@nix-apollo I suggested this here https://github.com/neelnanda-io/TransformerLens/pull/536#discussion_r1569800199 but deferred to Connor and Joseph, who have used the library.
Persistent state seems OK to me? TransformerLens already has `add_hook` and `reset_hooks`, which are really similar to `attach_sae` and `turn_off_all_saes`.
To me, `HookedSAETransformer.attach_sae` seems more analogous to `add_hook`. I like having this version of statefulness! Having an object that represents "GPT2 with this set of SAEs attached" seems useful to me. But being able to represent "GPT2 with these 12 SAEs attached, except only 4 of them are turned on" doesn't seem useful.
But I generally dislike persistent state and avoid it when possible. This includes avoiding `add_hook` and `reset_hooks` whenever possible in normal TL. Maybe my opinion is too extreme here -- again, do feel free to ignore it!
The flag is probably a fine middle ground. I still personally think the on/off feature adds a bunch of complexity for features that are more dangerous than useful.
Thanks for the feedback, Nix! I think there are cases where I want a stateful workflow, e.g. I have a single residual stream SAE that I want to attach to the model and then just do normal mech interp on. Or I have a fixed suite of SAEs and want to include all of them (plus error terms) and do circuit analysis. I also think there are cases where I want an ephemeral, stateless workflow, e.g. if I'm analysing feature splitting and have 8 SAEs of different widths on the same layer's residual stream, and want to easily swap them out for each other. So it seems clearly good to support both. IMO the best way to avoid footguns is to have `run_with_saes` return the SAEs to their original state after it runs, and to support both `attach_saes`/`turn_all_saes_off` and `run_with_saes`/`run_with_saes_and_cache`/`run_with_saes_and_hooks`.
@neelnanda-io I moved the discussion of this issue to the Open Source Slack, as the interface there is better.
TL;DR: support both permanently attached SAEs via `add_sae` and transiently attached SAEs via `run_with_saes`, matching TransformerLens syntax; do not use on/off at all.
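As a rough illustration of that resolution, here is a sketch in which SAEs are simply attached or not, with no on/off flags: `add_sae` attaches permanently, and `run_with_saes` attaches a list of SAEs only for one forward pass, via a context manager that restores the previous attachments. The class `ToySAEModel`, the `saes_attached` context manager, the `(hook, sae)` argument shape, and the scalar "activations" are all hypothetical choices for this sketch, not the API that was eventually merged.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List, Tuple

SAE = Callable[[float], float]  # toy stand-in: an SAE maps an activation to its reconstruction


class ToySAEModel:
    """Hypothetical toy model: an SAE is either attached at a hook or not; no on/off flags."""

    HOOKS = ("blocks.0.hook_resid_pre", "blocks.1.hook_resid_pre")

    def __init__(self) -> None:
        self.saes: Dict[str, SAE] = {}  # hook name -> attached SAE

    def add_sae(self, hook: str, sae: SAE) -> None:
        """Permanently attach (or replace) the SAE at `hook` (stateful workflow)."""
        self.saes[hook] = sae

    def reset_saes(self) -> None:
        self.saes.clear()

    def forward(self, x: float) -> float:
        # Two toy "layers": an attached SAE replaces the activation at its hook.
        for hook in self.HOOKS:
            if hook in self.saes:
                x = self.saes[hook](x)
            x = x + 1.0
        return x

    @contextmanager
    def saes_attached(self, saes: List[Tuple[str, SAE]]) -> Iterator["ToySAEModel"]:
        """Temporarily attach `saes`, then restore whatever was attached before."""
        saved = dict(self.saes)
        try:
            for hook, sae in saes:
                self.saes[hook] = sae
            yield self
        finally:
            self.saes = saved

    def run_with_saes(self, x: float, saes: List[Tuple[str, SAE]]) -> float:
        """Ephemeral workflow: the given SAEs apply to this forward pass only."""
        with self.saes_attached(saes):
            return self.forward(x)


model = ToySAEModel()
model.add_sae("blocks.0.hook_resid_pre", lambda a: 0.5 * a)  # permanent attachment
baseline = model.forward(2.0)

# Swap several stand-in SAEs in at the same hook without disturbing the permanent one.
swapped = [
    model.run_with_saes(2.0, [("blocks.0.hook_resid_pre", lambda a, k=k: a / k)])
    for k in (1.0, 2.0, 4.0)
]
after = model.forward(2.0)
print(baseline, swapped, after)
```

Dropping the flags means the model's state answers only one question ("which SAEs are attached?"), which is the ambiguity the thread was trying to remove.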