Easy-Transformer
[Bug Report] Can't add hook to pretrained model: AssertionError: Cannot add hook blocks.0.hook_q_input if use_split_qkv_input is False
Describe the bug
Running the Attribution Patching demo fails with the AssertionError below when adding hooks to a pretrained model.
Code example
See the patching section in the notebook (https://neelnanda.io/attribution-patching-demo), also in this repo. The error and stack trace are below.
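For reference, a minimal sketch of the failing pattern (the model name is a placeholder; the offending call is the blanket `lambda name: True` filter used in the demo):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # any pretrained model

cache = {}

def forward_cache_hook(act, hook):
    cache[hook.name] = act.detach()

# Hooking every single hook point trips the use_split_qkv_input assertion below.
model.add_hook(lambda name: True, forward_cache_hook, "fwd")
```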
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-17-996e57c2bc08> in <module>
     16     return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)
     17
---> 18 clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)
     19 print("Clean Value:", clean_value)
     20 print("Clean Activations Cached:", len(clean_cache))

2 frames

<ipython-input-17-996e57c2bc08> in get_cache_fwd_and_bwd(model, tokens, metric)
      4     def forward_cache_hook(act, hook):
      5         cache[hook.name] = act.detach()
----> 6     model.add_hook(lambda name: True, forward_cache_hook, "fwd")
      7
      8     grad_cache = {}

/usr/local/lib/python3.9/dist-packages/transformer_lens/hook_points.py in add_hook(self, name, hook, dir, is_permanent)
    150         for hook_point_name, hp in self.hook_dict.items():
    151             if name(hook_point_name):
--> 152                 self.check_and_add_hook(hp, hook_point_name, hook, dir=dir, is_permanent=is_permanent)
    153
    154     def add_perma_hook(self, name, hook, dir="fwd") -> None:

/usr/local/lib/python3.9/dist-packages/transformer_lens/HookedTransformer.py in check_and_add_hook(self, hook_point, hook_point_name, hook, dir, is_permanent)
    155             assert self.cfg.use_attn_result, f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
    156         if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
--> 157             assert self.cfg.use_split_qkv_input, f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
    158         hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent)
    159

AssertionError: Cannot add hook blocks.0.hook_q_input if use_split_qkv_input is False
```
System Info
- Colab notebook.
Additional context
- We recently had PRs related to split QKV, I believe.
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
@ArthurConmy Any ideas here?
Oh rip, that's probably because I explicitly wrote the caching in the `get_cache_fwd_and_bwd` function, and I think it just adds a hook to every single hook point in the model. In my opinion it is incorrect to raise an error here; you should either give a warning or do nothing (warnings are annoying and verbose, but so are silent bugs). Is this the kind of thing that logging levels are supposed to solve?
Maybe I should keep all my demo notebooks on their own branches again...
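For reference, a minimal sketch of the workaround described above: only hook the points that are active under the default config. The suffix list is inferred from the assertions in `check_and_add_hook`, not taken from an official API, and `model` / `forward_cache_hook` are the objects from the demo.

```python
# Sketch: skip the hook points that are gated behind config flags by default.
GATED_SUFFIXES = ("hook_q_input", "hook_k_input", "hook_v_input", "hook_result")

def default_hooks_only(name):
    return not name.endswith(GATED_SUFFIXES)

# forward_cache_hook as defined inside the demo's get_cache_fwd_and_bwd
model.add_hook(default_hooks_only, forward_cache_hook, "fwd")
```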
Logging levels could definitely help with this; we could add a logger to the package.
I was thinking we should move all the tutorials into a folder and actually run them semi-periodically, to check for drift and make sure we haven't broken anything (until the tests are more comprehensive).
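As a rough illustration (a hypothetical sketch, not existing TransformerLens code; the function and logger names are made up), the gating check could log a warning and skip the hook instead of asserting:

```python
import logging

logger = logging.getLogger("transformer_lens")  # hypothetical package-level logger

def check_and_maybe_add_hook(cfg, hook_point, hook_point_name, hook, dir="fwd"):
    """Sketch of a non-raising variant of check_and_add_hook."""
    if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")) and not cfg.use_split_qkv_input:
        logger.warning("Skipping hook %s: use_split_qkv_input is False", hook_point_name)
        return
    hook_point.add_hook(hook, dir=dir)
```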
Hi all,
First of all, thanks for putting out this amazing package, it really is great!!
I've been trying to reproduce the same demo, and I managed to avoid the AssertionError mentioned above by changing the lambda filter to:
```python
def filter_not_qkv_input(name):
    excluded_substrings = [
        "_input",
        "attn.hook_result",
        "hook_q_input",
        "hook_k_input",
        "hook_v_input",
        "mlp_in",
        "attn_in",
    ]
    return not any(sub in name for sub in excluded_substrings)
```
I've also changed the `get_cache_fwd_and_bwd` function to:
```python
def get_cache_fwd_and_bwd(model, tokens, metric):
    torch.set_grad_enabled(True)
    model.reset_hooks()

    cache = {}
    sums_fwd = []

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
        sums_fwd.append(torch.sum(cache[hook.name]).cpu())

    model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

    grad_cache = {}
    sums_bwd = []

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
        sums_bwd.append(torch.sum(grad_cache[hook.name]).cpu())

    model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

    value = metric(model(tokens))
    value.backward()
    model.reset_hooks()
    return (
        value.item(),
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
        sums_fwd,
        sums_bwd,
    )
```
Here I've enabled gradient computation (needed for the backward pass). This yields:
```
Clean Value: 1.0
Clean Activations Cached: 208
Clean Gradients Cached: 208
Corrupted Value: 0.0
Corrupted Activations Cached: 208
Corrupted Gradients Cached: 208
```
You'll also notice that I'm saving the sum of each hook's activation for the forward and backward passes. I've checked that both the forward and backward sums for the clean cache are non-zero, as is the corrupted forward pass. However, the corrupted backward pass is always exactly 0. When I then try to compute the rest of the functions (for example, `attr_patch_residual`), I get all zeros, because `corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)` is always 0 and the `einops.reduce` operation returns zeros, even though I've checked that the other components (`clean_residual` and `corrupted_residual`) are not 0.
What am I missing here?
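Not an answer, but a small diagnostic sketch that might help narrow this down. It only uses objects already defined in the demo (`model`, `corrupted_tokens`, `ioi_metric`): check that the corrupted metric value is still attached to the autograd graph, and probe whether any gradient reaches an early activation.

```python
import torch

torch.set_grad_enabled(True)
model.reset_hooks()

def probe_grad(act, hook):
    # Fires on the backward pass; a zero norm here means no gradient reached this point.
    print(hook.name, "grad norm:", act.norm().item())

model.add_hook(lambda n: n == "blocks.0.hook_resid_pre", probe_grad, "bwd")

value = ioi_metric(model(corrupted_tokens))
print("requires_grad:", value.requires_grad, "grad_fn:", value.grad_fn)
value.backward()
model.reset_hooks()
```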