Qwen2.5
qwen2: vllm reports an error when loading a checkpoint saved under zero3
Excuse me, I fine-tuned qwen2-1.5b using deepspeed zero3, but when I load the saved checkpoint with vllm, the following error occurs:
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
AssertionError
Someone says this problem seems to be caused by the model file not being saved completely after zero3 fine-tuning: https://github.com/vllm-project/vllm/issues/3813. However, checkpoints saved under stage 1 and stage 2 load fine.
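For comparison, I understand the more common route is to let DeepSpeed write its own sharded checkpoint and consolidate it afterwards. Below is a minimal sketch of that path, assuming model_engine is the engine returned by deepspeed.initialize() and "ckpt_dir" is a placeholder directory; get_fp32_state_dict_from_zero_checkpoint comes from deepspeed.utils.zero_to_fp32:

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# every rank participates; DeepSpeed writes its own sharded zero3 checkpoint
# (model_engine: the DeepSpeed engine from deepspeed.initialize(), assumed here)
model_engine.save_checkpoint("ckpt_dir")

# afterwards (on one process, or offline), consolidate the shards into a
# single full fp32 state dict and save it as a single weights file
state_dict = get_fp32_state_dict_from_zero_checkpoint("ckpt_dir")
torch.save(state_dict, "ckpt_dir/pytorch_model.bin")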
Here is the code I use to save the checkpoint under zero3:
import os

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

# print_rank_0 and save_hf_format are helper functions defined elsewhere in my
# training script (they follow the DeepSpeedExamples utilities).


def save_checkpoint(args, output_dir, model, tokenizer, total_step):
    print_rank_0("saving the final model ...", args.global_rank)
    output_dir = os.path.join(output_dir, f"chpk-{total_step}")
    os.makedirs(output_dir, exist_ok=True)
    # model = convert_lora_to_linear_layer(model)
    if args.global_rank == 0:
        # rank 0 writes the tokenizer/config in HF format
        save_hf_format(model, tokenizer, args, output_dir=output_dir)
    if args.zero_stage == 3:
        # For zero stage 3, each gpu only has a part of the model,
        # so we need a special save function (called on all ranks)
        save_zero_three_model(
            model, args.global_rank, output_dir, zero_stage=args.zero_stage
        )


def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    zero_stage_3 = zero_stage == 3
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)
    model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema
    if not zero_stage_3:
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():
            if hasattr(v, "ds_id"):
                # zero3 parameter: gather the full tensor from all ranks first
                with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([v]), enabled=zero_stage_3
                ):
                    v_p = v.data.cpu()
            else:
                v_p = v.cpu()
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        del output_state_dict


def _z3_params_to_fetch(param_list):
    # only fetch parameters that zero3 has partitioned away from this rank
    return [
        p
        for p in param_list
        if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]
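To check whether the checkpoint was actually saved completely, here is a small diagnostic sketch I would run against the saved file ("chpk-1000" is a placeholder for one of my checkpoint directories; note qwen2-1.5b ties lm_head to the embedding, so lm_head.weight may legitimately be absent):

import torch
from transformers import AutoConfig

ckpt_dir = "chpk-1000"  # placeholder: one of the saved checkpoint directories
config = AutoConfig.from_pretrained(ckpt_dir)
state_dict = torch.load(f"{ckpt_dir}/pytorch_model.bin", map_location="cpu")

# vllm asserts loaded_weight.shape[parallel_dim] == org_vocab_size, so the
# embedding (and an untied lm_head) should have config.vocab_size rows
for name in ("model.embed_tokens.weight", "lm_head.weight"):
    if name in state_dict:
        print(name, tuple(state_dict[name].shape), "expected rows:", config.vocab_size)
    else:
        print(name, "not in checkpoint (fine if the weight is tied)")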
Can you help me figure out what's wrong with this code? Thank you very much.