[Question] load_state_dict: copy vs assign

Open coolvision opened this issue 1 year ago • 3 comments

Question

As I understand it, TransformerLens currently copies tensors when loading a pretrained model with HookedTransformer.from_pretrained, via the call chain load_and_process_state_dict -> self.load_state_dict(state_dict).

Would it make sense to avoid copying and assign them instead? Recent PyTorch has a convenient option for this:

https://pytorch.org/docs/2.1/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict

load_state_dict(state_dict, strict=True, assign=False)

assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers.

This would reduce memory use, but I'm not sure if there are any drawbacks -- what's the reason for copying?

It looks like the only issue is that the process_weights_ feature would not work, which might be fine if assignment were optional?

coolvision avatar Jan 15 '24 21:01 coolvision

Interesting, I wasn't aware of that parameter, thanks! Looks like it's new to torch 2.0. I think that TransformerLens doesn't require torch 2.0, so can't use it.

On the other hand, the memory issues when loading from pretrained are a pretty big problem, and this is another source of them! I'd be excited for a PR that checks whether you're on torch v2.0 or above, and uses assign=True if so, otherwise doesn't use the keyword at all.

neelnanda-io avatar Jan 15 '24 22:01 neelnanda-io

It's actually 2.1; it's not even in 2.0: https://pytorch.org/docs/2.0/generated/torch.jit.ScriptModule.html?highlight=load_state_dict#torch.jit.ScriptModule.load_state_dict

ok, will test it and make a PR.
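
The version check suggested above could look roughly like the following sketch. The helper names supports_assign and load_state_dict_kwargs are hypothetical (they are not part of TransformerLens or PyTorch), and the 2.1 threshold is taken from the correction in this thread. The helpers work on a plain version string so they can be tried without PyTorch installed.

```python
import re


def supports_assign(torch_version: str) -> bool:
    """Return True if the version string is 2.1 or newer.

    Hypothetical helper: accepts strings such as "2.1.0+cu118".
    Unparseable strings are treated as unsupported.
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        return False
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 1)


def load_state_dict_kwargs(torch_version: str) -> dict:
    """Build extra keyword arguments for load_state_dict.

    The assign keyword is only passed when it is supported;
    otherwise it is omitted entirely, as proposed in the thread.
    """
    return {"assign": True} if supports_assign(torch_version) else {}


# Intended usage (assumes torch is imported and a model exists):
#   model.load_state_dict(
#       state_dict, **load_state_dict_kwargs(torch.__version__)
#   )
```

Omitting the keyword on older versions, rather than passing assign=False, avoids a TypeError on installs where load_state_dict does not accept it.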

One side effect of this feature -- it might simplify support for loading quantized models. I started looking into loading HuggingFace Llama quantized with bitsandbytes, into TransformerLens. One of the issues was that after load_state_dict copying, quantization attributes were lost. Assignment might retain them.

coolvision avatar Jan 16 '24 09:01 coolvision

I constantly ran into OOMs but fixed it using assign, as suggested: #724. I'm just using my local fork for now, but if anyone wants to review/merge/fork my PR, feel free! It basically allows users to use assign for any model, rather than only for 4-bit quants, which @coolvision already added.

cyber-chris avatar Sep 15 '24 16:09 cyber-chris