Size mismatch for base_model.model.transformer for bloomz model finetuning
After finetuning bigscience/bloomz-7b1, I encountered this issue while running evaluation.
│ envs/lmflow/lib/python3.9/site-packages/peft/utils/save_and_load.py:74 │
│ in set_peft_model_state_dict │
│ │
│ 71 │ │ peft_model_state_dict (dict): The state dict of the Peft model. │
│ 72 │ """ │
│ 73 │ │
│ ❱ 74 │ model.load_state_dict(peft_model_state_dict, strict=False) │
│ 75 │ if model.peft_config.peft_type != PeftType.LORA: │
│ 76 │ │ model.prompt_encoder.embedding.load_state_dict( │
│ 77 │ │ │ {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True │
│ │
│ envs/lmflow/lib/python3.9/site-packages/torch/nn/modules/module.py:1604 in load_state_dict │
│ │
│ 1601 │ │ │ │ │ │ ', '.join('"{}"'.format(k) for k in missing_keys))) │
│ 1602 │ │ │
│ 1603 │ │ if len(error_msgs) > 0: │
│ ❱ 1604 │ │ │ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( │
│ 1605 │ │ │ │ │ │ │ self.__class__.__name__, "\n\t".join(error_msgs))) │
│ 1606 │ │ return _IncompatibleKeys(missing_keys, unexpected_keys) │
│ 1607 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
size mismatch for base_model.model.transformer.h.0.self_attention.query_key_value.lora_A.weight: copying a param with shape torch.Size([0]) from
checkpoint, the shape in current model is torch.Size([16, 4096]).
size mismatch for base_model.model.transformer.h.0.self_attention.query_key_value.lora_B.weight: copying a param with shape torch.Size([0]) from
checkpoint, the shape in current model is torch.Size([8192, 8, 1]).
size mismatch for base_model.model.transformer.h.1.self_attention.query_key_value.lora_A.weight: copying a param with shape torch.Size([0]) from
checkpoint, the shape in current model is torch.Size([16, 4096]).
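The torch.Size([0]) shapes suggest that the LoRA tensors stored in the checkpoint itself are empty. A minimal diagnostic sketch to confirm this, assuming the adapter was saved as adapter_model.bin under the output directory (adjust the path to your setup):

```python
import torch

# Hypothetical checkpoint path; replace with your actual output directory.
checkpoint_path = "output_dir/adapter_model.bin"

state_dict = torch.load(checkpoint_path, map_location="cpu")
for name, tensor in state_dict.items():
    # Report any weights that were saved as empty (zero-element) tensors.
    if tensor.numel() == 0:
        print(f"Empty tensor in checkpoint: {name}, shape {tuple(tensor.shape)}")
```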
Hi, Could you share the command you are using? Thanks!
Hihi, I think I got the solution. I was using the zero3 config, which was causing the error. I switched to zero2 and the problem is gone. Zero2 finetuning is also much faster than zero3. Do you know why that is?
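For reference, a minimal ZeRO-2 config sketch of the kind I switched to (the field names are standard DeepSpeed options, but the specific values below are illustrative assumptions, not the exact config used here):

```python
import json

# Illustrative ZeRO-2 DeepSpeed config; values are assumptions, tune for your run.
ds_config_zero2 = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "allgather_partitions": True,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

# Write it out so it can be passed to the training script's deepspeed argument.
with open("ds_config_zero2.json", "w") as f:
    json.dump(ds_config_zero2, f, indent=2)
```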
It's strange. I have never encountered such an issue. In my case, zero3 works as well as zero2.