Automatic-Circuit-Discovery
Be more efficient with corrupted caching
for k in exp.global_cache.corrupted_cache.keys():
    print(k, exp.global_cache.corrupted_cache[k].shape, k in exp.global_cache.online_cache)
prints many entries that do not need to be cached:
blocks.0.ln1.hook_scale torch.Size([40, 41, 8, 1]) False
blocks.0.ln1.hook_normalized torch.Size([40, 41, 8, 512]) False
...
blocks.0.attn.hook_attn_scores torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_pattern torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_z torch.Size([40, 41, 8, 64]) False
We should cache only the activations that are actually needed. This should save memory and let us use bigger models when corrupted_cache_cpu=False.
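
A rough sketch of the idea, to show how much could be saved: keep only the hook names that are needed and drop the rest. This is not the repository's implementation. `prune_cache`, `num_elements`, and the `needed` set are hypothetical names introduced here, a plain dict of shapes stands in for `corrupted_cache`, and the choice of `hook_z` as the only required entry is purely an assumption for illustration.

```python
from math import prod

# Shapes copied from the output above. A plain dict of shapes stands in for
# exp.global_cache.corrupted_cache so the sketch runs without a model.
corrupted_shapes = {
    "blocks.0.ln1.hook_scale": (40, 41, 8, 1),
    "blocks.0.ln1.hook_normalized": (40, 41, 8, 512),
    "blocks.0.attn.hook_attn_scores": (40, 8, 41, 41),
    "blocks.0.attn.hook_pattern": (40, 8, 41, 41),
    "blocks.0.attn.hook_z": (40, 41, 8, 64),
}


def prune_cache(cache, needed_names):
    """Return a new dict holding only the entries whose hook name is needed."""
    return {name: value for name, value in cache.items() if name in needed_names}


def num_elements(shapes):
    """Total number of tensor elements described by a dict of shapes."""
    return sum(prod(shape) for shape in shapes.values())


# Hypothetical assumption: only hook_z is required downstream.
needed = {"blocks.0.attn.hook_z"}

pruned = prune_cache(corrupted_shapes, needed)
before = num_elements(corrupted_shapes)
after = num_elements(pruned)
print(f"elements before: {before}, after: {after}")
```

Filtering after the fact only frees memory once the full cache has already been built. Restricting which hooks are recorded during the corrupted forward pass would avoid the peak entirely.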