
[Bug Report] Can't add hook to pretrained model: AssertionError: Cannot add hook blocks.0.hook_q_input if use_split_qkv_input is False

Open jbloomAus opened this issue 1 year ago • 4 comments

Describe the bug: The attribution patching demo fails with the AssertionError below when adding hooks to a pretrained model.

Code example: see the patching section in the notebook (https://neelnanda.io/attribution-patching-demo), also in this repo. The error message and stack trace:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
[<ipython-input-17-996e57c2bc08>](https://localhost:8080/#) in <module>
     16     return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)
     17 
---> 18 clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)
     19 print("Clean Value:", clean_value)
     20 print("Clean Activations Cached:", len(clean_cache))

2 frames
[<ipython-input-17-996e57c2bc08>](https://localhost:8080/#) in get_cache_fwd_and_bwd(model, tokens, metric)
      4     def forward_cache_hook(act, hook):
      5         cache[hook.name] = act.detach()
----> 6     model.add_hook(lambda name: True, forward_cache_hook, "fwd")
      7 
      8     grad_cache = {}

[/usr/local/lib/python3.9/dist-packages/transformer_lens/hook_points.py](https://localhost:8080/#) in add_hook(self, name, hook, dir, is_permanent)
    150             for hook_point_name, hp in self.hook_dict.items():
    151                 if name(hook_point_name):
--> 152                     self.check_and_add_hook(hp, hook_point_name, hook, dir=dir, is_permanent=is_permanent)
    153 
    154     def add_perma_hook(self, name, hook, dir="fwd") -> None:

[/usr/local/lib/python3.9/dist-packages/transformer_lens/HookedTransformer.py](https://localhost:8080/#) in check_and_add_hook(self, hook_point, hook_point_name, hook, dir, is_permanent)
    155             assert self.cfg.use_attn_result, f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
    156         if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
--> 157             assert self.cfg.use_split_qkv_input, f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
    158         hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent)
    159 

AssertionError: Cannot add hook blocks.0.hook_q_input if use_split_qkv_input is False

System Info:

  • Colab notebook.

Additional context

  • We recently had PRs related to split QKV, I believe.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

jbloomAus · Mar 19 '23 21:03

@ArthurConmy Any ideas here?

jbloomAus · Mar 19 '23 21:03

Oh rip, that's probably because I explicitly wrote the caching in the get_cache_fwd_and_bwd function, and I think it just adds a hook to every single hook point in the model. In my opinion it is incorrect to raise an error here; you should either give a warning or nothing (warnings are annoying and verbose, but so are silent bugs). Is this the kind of thing that logging levels are supposed to solve?

Maybe I should keep all my demo notebooks on their own branches again...
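
For anyone hitting this before a fix lands, here is a minimal sketch of one possible workaround, assuming the config flags checked by these assertions can simply be flipped on the loaded model (untested here; both flags trade extra memory for the extra activations):

# Sketch of a possible workaround (an assumption, not the demo's intended fix):
# flip the config flags that gate these hook points so the assertions in
# check_and_add_hook pass and the blanket hook can be added.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # any pretrained model

model.cfg.use_split_qkv_input = True  # gates blocks.*.hook_{q,k,v}_input
model.cfg.use_attn_result = True      # gates blocks.*.attn.hook_result

cache = {}

def forward_cache_hook(act, hook):
    cache[hook.name] = act.detach()

model.add_hook(lambda name: True, forward_cache_hook, "fwd")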


neelnanda-io · Mar 19 '23 22:03

Logging levels could definitely help with this; we could add a logger to the package.
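
As a rough illustration of what that could look like, a sketch (the function name and warning wording here are made up, not the package's actual API):

import logging

logger = logging.getLogger("transformer_lens")

def warn_if_hook_gated(flag_enabled, hook_point_name):
    # Hypothetical alternative to the assert in check_and_add_hook: log a
    # warning instead of raising, so blanket hooks (lambda name: True) still work.
    if not flag_enabled:
        logger.warning(
            "Hook %s was added but its config flag is False; it will not "
            "receive activations during the forward pass.",
            hook_point_name,
        )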

I was thinking we should move all the tutorials into a folder and actually run them semi-periodically to check for drift/that we haven't broken anything (until the tests are more comprehensive).
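
One lightweight way that could look, sketched under the assumption that the tutorials live as notebooks in a tutorials/ folder (the folder name and output naming are illustrative):

# Execute every tutorial notebook and fail loudly if any cell errors.
# Requires jupyter/nbconvert to be installed; paths are illustrative.
import pathlib
import subprocess

for nb in sorted(pathlib.Path("tutorials").glob("*.ipynb")):
    subprocess.run(
        ["jupyter", "nbconvert", "--to", "notebook", "--execute",
         "--output", f"executed_{nb.name}", str(nb)],
        check=True,  # raises CalledProcessError if execution fails
    )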

jbloomAus · Mar 26 '23 04:03

Hi all,

First of all, thanks for putting out this amazing package, it really is great!!

I've been trying to reproduce that same demo, and I managed to avoid the AssertionError by replacing the lambda filter with:

def filter_not_qkv_input(name):
    # Skip hook points that are gated behind optional config flags
    # (split q/k/v inputs, attn result, mlp_in, attn_in), so a blanket
    # fwd/bwd cache hook can be added without tripping the assertions.
    excluded_substrings = [
        "_input",  # already covers hook_q_input / hook_k_input / hook_v_input
        "attn.hook_result",
        "hook_q_input",
        "hook_k_input",
        "hook_v_input",
        "mlp_in",
        "attn_in",
    ]
    return not any(sub in name for sub in excluded_substrings)

I've also changed the get_cache_fwd_and_bwd function into:

def get_cache_fwd_and_bwd(model, tokens, metric):
    # Gradients may have been disabled globally earlier in the notebook;
    # re-enable them so the backward pass populates grad_cache.
    torch.set_grad_enabled(True)
    model.reset_hooks()
    cache = {}
    sums_fwd = []

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()
        sums_fwd.append(torch.sum(cache[hook.name]).cpu())  # debug: per-hook activation sum

    model.add_hook(filter_not_qkv_input, forward_cache_hook, "fwd")

    grad_cache = {}
    sums_bwd = []

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
        sums_bwd.append(torch.sum(grad_cache[hook.name]).cpu())  # debug: per-hook gradient sum

    model.add_hook(filter_not_qkv_input, backward_cache_hook, "bwd")

    value = metric(model(tokens))
    value.backward()
    model.reset_hooks()

    return (
        value.item(),
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
        sums_fwd,
        sums_bwd,
    )

Here I've explicitly enabled gradients (needed for the backward pass). This yields:

Clean Value: 1.0
Clean Activations Cached: 208
Clean Gradients Cached: 208
Corrupted Value: 0.0
Corrupted Activations Cached: 208
Corrupted Gradients Cached: 208
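
For reference, the calls producing the numbers above look roughly like this (clean_tokens, corrupted_tokens and ioi_metric are assumed to be defined as in the demo notebook):

# Driver calls mirroring the demo notebook; the unpacking matches the
# modified five-value return signature above.
clean_value, clean_cache, clean_grad_cache, clean_sums_fwd, clean_sums_bwd = get_cache_fwd_and_bwd(
    model, clean_tokens, ioi_metric
)
corrupted_value, corrupted_cache, corrupted_grad_cache, corrupted_sums_fwd, corrupted_sums_bwd = get_cache_fwd_and_bwd(
    model, corrupted_tokens, ioi_metric
)

print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
print("Clean Gradients Cached:", len(clean_grad_cache))
print("Corrupted Value:", corrupted_value)
print("Corrupted Activations Cached:", len(corrupted_cache))
print("Corrupted Gradients Cached:", len(corrupted_grad_cache))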

You'll also notice that I'm saving the sum of the fwd and bwd activations for each hook. I've checked that both the fwd and bwd sums for the clean cache are non-zero, as is the corrupted fwd. However, the corrupted bwd is always exactly 0. When I then try to compute the results for the rest of the functions (for example, attr_patch_residual), I get all 0s, because corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False) is always 0 and the einops.reduce operation returns 0s, even though I've checked that the other components (clean_residual and corrupted_residual) are not 0.

What am I missing here?

atlaie · Apr 30 '24 09:04