Automatic-Circuit-Discovery

Be more efficient with corrupted caching

Status: Open. ArthurConmy opened this issue 1 year ago (0 comments).

for k in exp.global_cache.corrupted_cache.keys():
    print(k, exp.global_cache.corrupted_cache[k].shape, k in exp.global_cache.online_cache)

prints many entries we don't need (none of them are in the online cache, as the last column shows):

blocks.0.ln1.hook_scale torch.Size([40, 41, 8, 1]) False
blocks.0.ln1.hook_normalized torch.Size([40, 41, 8, 512]) False
...
blocks.0.attn.hook_attn_scores torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_pattern torch.Size([40, 8, 41, 41]) False
blocks.0.attn.hook_z torch.Size([40, 41, 8, 64]) False
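To put rough numbers on the waste, here is a small sketch that multiplies out the shapes printed above. It assumes float32 activations (4 bytes per element) and counts only the five entries shown, since the "..." hides the rest; the names `SHAPES` and `tensor_bytes` are illustrative, not part of ACDC.

```python
from math import prod

# Shapes copied from the printout above; the elided "..." entries are not included.
SHAPES = {
    "blocks.0.ln1.hook_scale": (40, 41, 8, 1),
    "blocks.0.ln1.hook_normalized": (40, 41, 8, 512),
    "blocks.0.attn.hook_attn_scores": (40, 8, 41, 41),
    "blocks.0.attn.hook_pattern": (40, 8, 41, 41),
    "blocks.0.attn.hook_z": (40, 41, 8, 64),
}
BYTES_PER_ELEMENT = 4  # assumption: float32 activations


def tensor_bytes(shape, bytes_per_element=BYTES_PER_ELEMENT):
    """Memory taken by one dense tensor with the given shape."""
    return prod(shape) * bytes_per_element


total = sum(tensor_bytes(shape) for shape in SHAPES.values())
for name, shape in SHAPES.items():
    print(f"{name:32s} {tensor_bytes(shape) / 2**20:6.2f} MiB")
print(f"total: {total / 2**20:.2f} MiB")
```

Under that float32 assumption, these five block-0 entries alone come to 34,584,320 bytes (about 33 MiB), most of it from `hook_normalized`, and the cost repeats for every layer whose hooks are cached.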

We should only cache activations that definitely matter. This should save memory and let us use bigger models when corrupted_cache_cpu=False (i.e. when the corrupted cache is not offloaded to the CPU).
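One way to act on this, sketched under assumptions: build a predicate over hook names and keep only the entries it accepts. The helpers `make_names_filter` and `prune_cache` and the `needed` set are hypothetical; in practice `needed` would be the hooks the experiment actually reads (for example the keys of the online cache, which is the membership the loop above checks). TransformerLens's `run_with_cache` accepts a `names_filter` callable of the same `str -> bool` shape, but whether ACDC's corrupted cache can be populated through it is an assumption here.

```python
from typing import Callable, Dict, Iterable, TypeVar

T = TypeVar("T")


def make_names_filter(needed: Iterable[str]) -> Callable[[str], bool]:
    """Return a predicate that is True only for hook names we must keep.

    `needed` is hypothetical: the set of hooks the experiment reads from the
    corrupted cache, which ACDC would have to derive from its graph.
    """
    keep = frozenset(needed)
    return lambda name: name in keep


def prune_cache(cache: Dict[str, T], names_filter: Callable[[str], bool]) -> Dict[str, T]:
    """Return a new dict containing only the entries the filter accepts."""
    return {name: value for name, value in cache.items() if names_filter(name)}


# Hypothetical example: strings stand in for tensors, and only hook_z is needed.
cache = {
    "blocks.0.ln1.hook_scale": "tensor A",
    "blocks.0.attn.hook_pattern": "tensor B",
    "blocks.0.attn.hook_z": "tensor C",
}
keep_z = make_names_filter({"blocks.0.attn.hook_z"})
pruned = prune_cache(cache, keep_z)
print(sorted(pruned))  # ['blocks.0.attn.hook_z']
```

Pruning after the fact frees memory only once the full cache has already been built; the bigger win is applying the same predicate at capture time, so the unneeded activations are never stored at all.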

Posted by ArthurConmy, Oct 01 '23 15:10