[Bug Report] Unable to load Llama 3 70B on multi-GPU in 4-bit
Unable to load Llama 3 70B on multiple GPUs in 4-bit:
import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-70B-Instruct', torch_dtype=torch.bfloat16, device_map="auto", load_in_4bit=True)
model = HookedTransformer.from_pretrained(
    'meta-llama/Meta-Llama-3-70B-Instruct',
    hf_model=base_model,
    fold_ln=False,
    fold_value_biases=False,
)
This errors with:
File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HookedTransformer:
size mismatch for blocks.0.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.0.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.1.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.1.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.2.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.2.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.3.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
size mismatch for blocks.3.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]).
... (etc)
Are you using the latest version from GitHub? There were some issues with Llama models that have been fixed.
You can update with pip install --force-reinstall --no-deps git+https://github.com/neelnanda-io/TransformerLens/
FWIW, seeing the same error on the current version.
Same error for Llama3-8b. Trying to run it on a single GPU in 4-bit.
Python version: 3.11, TransformerLens: 1.18, Transformers: 4.41.0
I also tried the latest version from GitHub and am still facing the same issue.
Same problem here! @me301 @noisefloordev @Butanium and @winglian, have any of you solved it yet?
Yeah, I got this problem with a llama-2-70b model too.
Hey all, I am making this my second priority at the moment. One thing to note is that multi-GPU support hasn't been documented very clearly in the project. A couple of people have brought it to my attention, and in most cases it was just a matter of passing n_devices={gpu_count} to HookedTransformer.from_pretrained. From what I see here, I don't think this is the same issue. However, while you are all waiting for me to address it, setting that parameter may make a difference when dealing with multiple GPUs; see the sketch below.
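For anyone who wants to try that path, here is a minimal sketch of plain multi-GPU loading via n_devices (no 4-bit quantization, since that is what is failing above); the model name and GPU count are assumptions, and defaults may differ between TransformerLens versions:

import torch
from transformer_lens import HookedTransformer

# Split the model's layers across all visible GPUs.
n_gpus = torch.cuda.device_count()
model = HookedTransformer.from_pretrained(
    'meta-llama/Meta-Llama-3-70B-Instruct',
    n_devices=n_gpus,
    dtype=torch.bfloat16,
)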
Getting the same for Llama-3-8b in 4-bit on a single GPU with transformer_lens 2.1.0, loading via from_pretrained_no_processing with the hf_model parameter. from_pretrained also throws an error.
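For reference, a minimal sketch of that loading path, assuming bitsandbytes is installed; the exact model name and flags are assumptions based on the reports above, not a confirmed repro:

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# 4-bit HF model via bitsandbytes; loading its state_dict into
# HookedTransformer is where the size-mismatch error is raised.
hf_model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B-Instruct',
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
model = HookedTransformer.from_pretrained_no_processing(
    'meta-llama/Meta-Llama-3-8B-Instruct',
    hf_model=hf_model,
)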
Having this same issue for Llama3-8b. Trying to run it on a single GPU in 4-bit mode.