
[Bug Report] Unable to load Llama 3 70b on multi-GPU in 4-bit

Open winglian opened this issue 1 year ago • 8 comments

Unable to load Llama 3 70b on multi-GPU:

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-70B-Instruct', torch_dtype=torch.bfloat16, device_map="auto", load_in_4bit=True)
model = HookedTransformer.from_pretrained(
    'meta-llama/Meta-Llama-3-70B-Instruct',
    hf_model=base_model,
    fold_ln=False,
    fold_value_biases=False,
)

errors with

  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict                                                                                                                                 
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                                                                                                                                                                                
RuntimeError: Error(s) in loading state_dict for HookedTransformer:                                                                                                                                                                                          
        size mismatch for blocks.0.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.0.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.1.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.1.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.2.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.2.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.3.attn._W_K: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        size mismatch for blocks.3.attn._W_V: copying a param with shape torch.Size([4194304, 1]) from checkpoint, the shape in current model is torch.Size([8, 8192, 128]). 
        ... (etc)

winglian commented on May 03, 2024

Are you using the latest version on GitHub? There were some issues with Llama models that were fixed. You can update with pip install --force-reinstall --no-deps git+https://github.com/neelnanda-io/TransformerLens/

Butanium commented on May 07, 2024

FWIW, seeing the same error on the current version.

noisefloordev commented on May 10, 2024

Same error for Llama3-8b. Trying to run it on a single GPU in 4-bit.

Python version: 3.11
TransformerLens: 1.18
Transformers: 4.41.0

Also tried using the latest version from GitHub and am still facing the same issue.

me301 commented on May 20, 2024

Same problem here! @me301 @noisefloordev @Butanium and @winglian, have any of you solved it yet?

winnieyangwannan commented on May 23, 2024

Yeah, I got this problem with a llama-2-70b model too.

jukofyork commented on Jun 05, 2024

Hey all, I am making this my second priority at the moment. One thing to note is that multi-GPU support hasn't really been made very clear in the project. A couple of people have brought it to my attention, and in most cases it has been a matter of setting n_devices={gpu_count} in HookedTransformer.from_pretrained. From what I see here, I don't think this is the same issue. However, while you are all waiting for me to address it, setting that may make a difference when dealing with multiple GPUs (see the sketch below).
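
For reference, a minimal sketch of that suggestion (the device count and model name here are placeholders, and this addresses multi-GPU sharding rather than the 4-bit size-mismatch error above):

from transformer_lens import HookedTransformer

# Hypothetical: shard the model across 2 GPUs by passing n_devices.
model = HookedTransformer.from_pretrained(
    'meta-llama/Meta-Llama-3-70B-Instruct',
    n_devices=2,  # number of GPUs to split the model across
    fold_ln=False,
    fold_value_biases=False,
)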

bryce13950 commented on Jun 06, 2024

Getting the same for Llama-3-8b, 4-bit, on a single GPU; transformer_lens 2.1.0, loading via from_pretrained_no_processing with the hf_model parameter (sketched below). from_pretrained also throws an error.
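
For context, that loading path looks roughly like this (a sketch, assuming bitsandbytes is installed; the model name mirrors the comment above and is not a confirmed reproduction):

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# Load the quantized HF model, then hand it to TransformerLens without weight processing.
hf_model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
model = HookedTransformer.from_pretrained_no_processing(
    'meta-llama/Meta-Llama-3-8B',
    hf_model=hf_model,
)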

rovle commented on Jun 14, 2024

Having this same issue for Llama3-8b. Trying to run it on a single GPU in 4-bit mode.

TacticalSpoon331 commented on Oct 09, 2024