
qwen2: vLLM reports an error when loading a checkpoint saved with ZeRO-3

Open cooper12121 opened this issue 6 months ago • 0 comments

Excuse me, I fine-tuned qwen2-1.5b using DeepSpeed ZeRO-3, but when I load the saved checkpoint with vLLM, it fails with:

assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
AssertionError

Someone suggested the problem is caused by the model weights not being saved completely after ZeRO-3 fine-tuning (https://github.com/vllm-project/vllm/issues/3813); the assertion appears to come from vLLM's vocab-parallel embedding weight loader, which checks that the checkpoint's embedding tensor matches the vocab size in config.json. However, checkpoints saved under ZeRO stages 1 and 2 load fine.
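
As a quick check, inspecting the tensor shapes in the saved pytorch_model.bin shows whether the embedding was written out in full or only as a ZeRO-3 placeholder. A minimal diagnostic sketch (the checkpoint path is a placeholder, and the [151936, 1536] expected shape for qwen2-1.5b should be verified against your own config.json):

import torch

# Hypothetical path; point this at the pytorch_model.bin inside your chpk-<step> dir.
state_dict = torch.load("chpk-1000/pytorch_model.bin", map_location="cpu")

for name, tensor in state_dict.items():
    if "embed" in name or "lm_head" in name:
        # For qwen2-1.5b the embedding should be [151936, 1536] (vocab x hidden);
        # an empty or near-empty placeholder here means the ZeRO-3 shard was
        # never gathered before saving.
        print(name, tuple(tensor.shape))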

Here is the code I use to save a checkpoint under ZeRO-3:

import os

import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

# print_rank_0 and save_hf_format are project helpers (DeepSpeed-Chat-style utils).

def save_checkpoint(args, output_dir, model, tokenizer, total_step):
    print_rank_0("saving the final model ...", args.global_rank)
    output_dir = os.path.join(output_dir, f"chpk-{total_step}")
    os.makedirs(output_dir, exist_ok=True)
    # model = convert_lora_to_linear_layer(model)
    if args.global_rank == 0:
        # Rank 0 writes the config and tokenizer files in HF format.
        save_hf_format(model, tokenizer, args, output_dir=output_dir)
    if args.zero_stage == 3:
        # Under ZeRO-3 each GPU holds only a shard of every parameter,
        # so the weights must be gathered before they can be saved.
        save_zero_three_model(
            model, args.global_rank, output_dir, zero_stage=args.zero_stage
        )

def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    zero_stage_3 = zero_stage == 3
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)
    model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema
    if not zero_stage_3:
        # Stages 0-2 keep full parameters on every rank; rank 0 can save directly.
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():
            if hasattr(v, "ds_id"):
                # ZeRO-3 partitioned parameter: gather the full tensor from
                # all ranks before copying it to CPU.
                with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([v]), enabled=zero_stage_3
                ):
                    v_p = v.data.cpu()
            else:
                v_p = v.cpu()
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        del output_state_dict

def _z3_params_to_fetch(param_list):
    # Only fetch parameters that are actually partitioned by ZeRO-3 and not
    # yet materialized on this rank.
    return [
        p
        for p in param_list
        if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]
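
For comparison, DeepSpeed's documentation recommends saving a native ZeRO checkpoint during training and consolidating it offline with the zero_to_fp32 utilities, instead of gathering parameters by hand. A minimal sketch of that route (directory names are placeholders; "ds_ckpt" is assumed to have been written by engine.save_checkpoint, where engine is the object returned by deepspeed.initialize):

from transformers import AutoModelForCausalLM
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# "ds_ckpt" is a placeholder for the directory written during training by
# engine.save_checkpoint("ds_ckpt").
state_dict = get_fp32_state_dict_from_zero_checkpoint("ds_ckpt")

# Load the consolidated fp32 weights into a plain HF model and re-save them
# in a layout vLLM can read directly.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B")
model.load_state_dict(state_dict)
model.save_pretrained("chpk-final")

vLLM also needs config.json and the tokenizer files alongside the weights, which save_hf_format already writes in my flow above.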

Can you help me figure out what's wrong with this code? Thank you very much.

cooper12121 · Aug 03 '24 04:08