LLaMA-Factory
The fine-tuned Gemma model encounters an error when loaded through vLLM: `KeyError: 'lm_head.weight'`
Reminder
- [X] I have read the README and searched the existing issues.
Reproduction
```shell
deepspeed --include="localhost:0,1,2,3,4,5,6,7" src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path gemma-7b \
    --dataset XXX.json \
    --overwrite_cache \
    --template gemma \
    --finetuning_type full \
    --cutoff_len 3072 \
    --output_dir ${OUTPUT} \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --dataloader_num_workers 5 \
    --logging_steps 10 \
    --save_steps 500 \
    --learning_rate 1e-5 \
    --save_only_model \
    --num_train_epochs 15 \
    --bf16 \
    --save_total_limit 2 \
    --report_to wandb \
    --flash_attn \
    --deepspeed configs/ds_stage_2.json
```
After fine-tuning all parameters of the gemma-7b model, the resulting checkpoint can be decoded normally using src/web_demo.py. However, when the model is loaded through vLLM, the following error is raised:
```
File "/opt/conda/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 109, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 391, in from_engine_args
    engine = cls(*engine_configs,
File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 128, in __init__
    self._init_workers()
File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 181, in _init_workers
    self._run_workers("load_model")
File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1041, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
File "/opt/conda/lib/python3.10/site-packages/vllm/worker/worker.py", line 100, in load_model
    self.model_runner.load_model()
File "/opt/conda/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 88, in load_model
    self.model = get_model(self.model_config,
File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/utils.py", line 52, in get_model
    return get_model_fn(model_config, device_config, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/model_loader.py", line 86, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/models/gemma.py", line 337, in load_weights
    param = params_dict[name]
KeyError: 'lm_head.weight'
```
The script for loading the model in vLLM is as follows:
```python
from vllm import LLM

llm = LLM(
    model=model_path,
    trust_remote_code=True,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
)
```
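A likely cause (an assumption, not confirmed by the logs): Gemma ties `lm_head.weight` to `model.embed_tokens.weight`, so vLLM's Gemma loader registers no separate `lm_head` parameter, while a full fine-tune checkpoint may still carry the duplicated tensor and trip the `params_dict[name]` lookup. A minimal sketch of one workaround is to filter the tied key out of the checkpoint's state dict before re-saving it; `drop_tied_head` is a hypothetical helper, and plain strings stand in for tensors here:

```python
def drop_tied_head(state_dict, tied_key="lm_head.weight"):
    """Return a copy of state_dict without the tied lm_head tensor.

    Gemma shares lm_head.weight with model.embed_tokens.weight, so a
    checkpoint that still contains the duplicate raises KeyError in
    vLLM's load_weights, which has no separate lm_head parameter.
    """
    return {k: v for k, v in state_dict.items() if k != tied_key}


# Toy example with strings standing in for tensors:
ckpt = {"model.embed_tokens.weight": "W", "lm_head.weight": "W"}
cleaned = drop_tied_head(ckpt)
print(sorted(cleaned))  # → ['model.embed_tokens.weight']
```

The same filtering could be applied to each saved shard (loaded with `torch.load` or safetensors) before pointing vLLM at the directory; newer vLLM releases reportedly skip the tied key themselves, so upgrading vLLM is also worth trying.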
Expected behavior
No response
System Info
- `transformers` version: 4.38.2
- Platform: Linux-4.19.91-kangaroo.2.10.11.6c7cc891b.alios7.x86_64-x86_64-with-glibc2.35
- Python version: 3.10.13
- Huggingface_hub version: 0.21.3
- Safetensors version: 0.4.1
- Accelerate version: 0.27.2
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): 2.14.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- vllm: 0.3.3
Others
No response