
Llama2 transfer to Llama3

Open Summoningg opened this issue 1 year ago • 5 comments

Can I move a Llama 2 task to Llama 3 just by loading Llama 3 with transformers, or do I need to rewrite some of the code?

When I loaded Llama 3, I got this error:

```
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
	size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
```
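The shapes in this message are consistent with grouped-query attention (GQA). Assuming the published Llama 3 8B configuration (`hidden_size=4096`, `num_attention_heads=32`, `num_key_value_heads=8`), each head is 4096 / 32 = 128 wide, so `k_proj` maps 4096 inputs to 8 × 128 = 1024 outputs. A config without `num_key_value_heads` (as in Llama 2 7B) falls back to one key/value head per attention head and expects 32 × 128 = 4096 outputs, which matches the `[4096, 4096]` shape reported for the current model. The plain-Python sketch below only does that arithmetic; `k_proj_weight_shape` is a hypothetical name.

```python
def k_proj_weight_shape(hidden_size, num_attention_heads, num_key_value_heads=None):
    """Return the (out_features, in_features) shape of a Llama k_proj weight.

    Hypothetical helper. nn.Linear stores its weight as [out, in]. When
    num_key_value_heads is missing, fall back to num_attention_heads
    (plain multi-head attention, as in Llama 2 7B).
    """
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads
    head_dim = hidden_size // num_attention_heads
    return (num_key_value_heads * head_dim, hidden_size)


# Llama 3 8B: 8 key/value heads shared across 32 query heads (GQA)
print(k_proj_weight_shape(4096, 32, 8))  # (1024, 4096): the checkpoint shape
# Same sizes, but no num_key_value_heads in the config
print(k_proj_weight_shape(4096, 32))     # (4096, 4096): the model shape
```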

When I added `ignore_mismatched_sizes=True`, I got this traceback instead:

```
Traceback (most recent call last):
  File "train.py", line 53, in <module>
    main()
  File "train.py", line 49, in main
    train(args)
  File "train.py", line 35, in train
    model = llama(args)
  File ".py", line 96, in __init__
    self.llama_model = AutoModelForCausalLM.from_pretrained(
  File "/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 484, in from_pretrained
    return model_class.from_pretrained(
  File "/lib/python3.8/site-packages/transformers/modeling_utils.py", line 2881, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3218, in _load_pretrained_model
    mismatched_keys += _find_mismatched_keys(
  File "/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3141, in _find_mismatched_keys
    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
KeyError: 'lm_head.weight'
```

How can I fix this, or do I need to rewrite the code?
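The last frame of this traceback indexes `state_dict[checkpoint_key]` directly, so the `KeyError` means `'lm_head.weight'` was not in the `state_dict` being compared at that point. Even when it does load, `ignore_mismatched_sizes=True` does not fix the first error: transformers skips the mismatched tensors and leaves those weights newly (randomly) initialized, which here would be the `k_proj` and `v_proj` weights of every layer. The sketch below illustrates the kind of shape comparison involved, using plain tuples instead of tensors and skipping keys that are absent on either side. It is a simplified, hypothetical illustration, not transformers' actual `_find_mismatched_keys`.

```python
def find_mismatched_keys(checkpoint_shapes, model_shapes):
    """Return (key, checkpoint_shape, model_shape) for keys present in both
    mappings whose shapes differ.

    Simplified illustration only: shapes are plain tuples, and keys missing
    from either side are skipped instead of raising KeyError.
    """
    mismatched = []
    for key, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(key)
        if model_shape is None:
            continue  # key not in the model; nothing to compare
        if ckpt_shape != model_shape:
            mismatched.append((key, ckpt_shape, model_shape))
    return mismatched


# Illustrative shapes taken from the error messages above
checkpoint = {
    "model.layers.0.self_attn.k_proj.weight": (1024, 4096),
    "model.layers.0.self_attn.q_proj.weight": (4096, 4096),
}
model = {
    "model.layers.0.self_attn.k_proj.weight": (4096, 4096),
    "model.layers.0.self_attn.q_proj.weight": (4096, 4096),
    "lm_head.weight": (128256, 4096),  # present in the model, not in this checkpoint dict
}
print(find_mismatched_keys(checkpoint, model))
```

Only the `k_proj` entry is reported; `lm_head.weight`, which exists only on the model side, is skipped rather than raising.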

Summoningg, Jun 21 '24

You should be able to switch seamlessly if you're using transformers. Please share the code you're running.

subramen, Jul 03 '24


Hello, and thank you for replying. The error occurs when loading Llama:

```python
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import LlamaForCausalLM, LlamaTokenizer

# Excerpt from the model class's __init__(self, args)
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_model, use_fast=False)
self.llama_tokenizer.pad_token_id = 0
if args.low_resource:
    # 8-bit weights, placed automatically across available devices
    self.llama_model = LlamaForCausalLM.from_pretrained(
        args.llama_model,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto"
    )
else:
    self.llama_model = LlamaForCausalLM.from_pretrained(
        args.llama_model,
        torch_dtype=torch.float16,
    )

if args.llm_use_lora:
    # Fine-tune through LoRA adapters on top of the base model
    self.embed_tokens = self.llama_model.get_input_embeddings()
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.llm_r, lora_alpha=args.llm_alpha, lora_dropout=args.lora_dropout
    )
    self.llama_model = get_peft_model(self.llama_model, peft_config)
    self.llama_model.print_trainable_parameters()
    print('Loading LLAMA LoRA Done')
else:
    # No LoRA: freeze every base-model parameter
    self.embed_tokens = self.llama_model.get_input_embeddings()
    for name, param in self.llama_model.named_parameters():
        param.requires_grad = False
    print('Loading LLAMA Done')
```
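One change worth trying in this snippet, offered as a hedged suggestion rather than a confirmed fix: Llama 3 ships a tiktoken-style BPE tokenizer instead of Llama 2's SentencePiece model, so the SentencePiece-based `LlamaTokenizer` with `use_fast=False` is unlikely to load it. `AutoTokenizer` and `AutoModelForCausalLM` read the tokenizer and model classes from the checkpoint's own files. Also, `pad_token_id = 0` is `<unk>` in the Llama 2 vocabulary but an ordinary text token in Llama 3's, whose tokenizer typically sets no pad token; reusing the EOS token is a common workaround. The sketch assumes an HF-format checkpoint (a Hub id or a directory with `config.json`) and a transformers version recent enough to support Llama 3; `load_llama` is a hypothetical name, and the imports are deferred so the function can be defined without torch or transformers installed.

```python
def load_llama(model_name_or_path, low_resource=False):
    """Load a tokenizer and causal LM from an HF-format Llama checkpoint.

    Hypothetical helper. Imports are deferred so this module can be
    imported even where torch/transformers are not installed.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # AutoTokenizer uses the tokenizer class recorded in the checkpoint,
    # instead of forcing the SentencePiece-based LlamaTokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding rather than hard-coding token id 0.
        tokenizer.pad_token = tokenizer.eos_token

    kwargs = {"torch_dtype": torch.float16}
    if low_resource:
        # 8-bit loading via bitsandbytes, spread across available devices.
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs)
    return tokenizer, model
```

In the original class this would be called as `self.llama_tokenizer, self.llama_model = load_llama(args.llama_model, args.low_resource)`, with the LoRA and parameter-freezing branches presumably left as they are.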

The Llama 3 8B I'm using is the version downloaded from Meta's website. Is it possible that something went wrong during the download?
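Checkpoints downloaded from Meta's website come in Meta's original format (`params.json`, `consolidated.*.pth` and a `tokenizer.model`), which `from_pretrained` does not read directly; it expects an HF-format directory with a `config.json` and weight shards. If the directory is still in Meta's format, transformers ships a conversion script (`convert_llama_weights_to_hf.py` under `transformers/models/llama/`), or the HF-format weights can be downloaded from the `meta-llama` repositories on the Hub. The sketch below, with hypothetical function names, only looks at file names to tell the two layouts apart; it does not verify the files' contents.

```python
from pathlib import Path


def layout_from_names(names):
    """Classify a checkpoint layout from its file names (hypothetical helper).

    Returns "hf" for a Hugging Face-style directory, "meta" for Meta's
    original download format, and "unknown" otherwise.
    """
    names = set(names)
    has_hf_weights = any(
        n.endswith(".safetensors")
        or (n.startswith("pytorch_model") and n.endswith(".bin"))
        for n in names
    )
    has_meta_weights = any(
        n.startswith("consolidated.") and n.endswith(".pth") for n in names
    )
    if "config.json" in names and has_hf_weights:
        return "hf"
    if "params.json" in names and has_meta_weights:
        return "meta"
    return "unknown"


def checkpoint_layout(model_dir):
    """Apply layout_from_names to the regular files directly inside model_dir."""
    return layout_from_names(p.name for p in Path(model_dir).iterdir() if p.is_file())


# Example: print(checkpoint_layout(args.llama_model))
```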

Summoningg, Jul 04 '24

The code snippet you shared should work. Can you confirm the value of `args.llama_model`? If you are using the HF API, it should be something like `meta-llama/Meta-Llama-3.1-8B-Instruct`.

subramen, Jul 31 '24


I have the same question. In my code, `args.llama_model` is the path to the model I downloaded.
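When `args.llama_model` points at a local HF-format directory, one quick check (a sketch, not a diagnosis) is whether its `config.json` actually describes Llama 3. Assuming the published Llama 3 8B values, two telling fields are `num_key_value_heads` (8, versus 32 or absent for Llama 2 7B) and `vocab_size` (128256, versus 32000). A config without `num_key_value_heads` would yield the `[4096, 4096]` `k_proj` shape from the original report. The function names and the expected-values table are hypothetical.

```python
import json
from pathlib import Path

# Published Llama 3 8B values (assumption for this check); Llama 2 7B uses
# 32 key/value heads (often left implicit) and a 32000-token vocabulary.
LLAMA3_8B_EXPECTED = {"num_key_value_heads": 8, "vocab_size": 128256}


def config_mismatches(config, expected=LLAMA3_8B_EXPECTED):
    """Return {field: (found, expected)} for fields that differ.

    Hypothetical helper. A missing or None num_key_value_heads falls back
    to num_attention_heads, mirroring how Llama configs treat it.
    """
    found = dict(config)
    if found.get("num_key_value_heads") is None:
        found["num_key_value_heads"] = found.get("num_attention_heads")
    return {
        key: (found.get(key), want)
        for key, want in expected.items()
        if found.get(key) != want
    }


def check_model_dir(model_dir):
    """Load <model_dir>/config.json and compare it against the Llama 3 8B values."""
    with open(Path(model_dir) / "config.json", encoding="utf-8") as f:
        return config_mismatches(json.load(f))


# Example: print(check_model_dir(args.llama_model))
```

An empty dict means both fields match; a non-empty result for a Llama 3 8B download suggests the `config.json` did not come from a Llama 3 checkpoint.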

Tzx11, Jan 03 '25

Hi, was a resolution found to this issue? I am also facing it. Thanks!

ShristiDasBiswas, Aug 06 '25