Missing key(s) in state_dict for bias in attention blocks

Open EikeKohl opened this issue 1 year ago • 1 comments

I am trying to run step 3 of the RLHF examples using a RewardModel checkpoint that I trained using step 2 of the examples. For every step, I used the provided sh scripts and only adjusted the model / data paths. Unfortunately, I encountered the following exception:

*******************[end] Initialized Ref Model [end] (duration: 0.59s)********************
************************[start] Initializing Critic Model [start] ************************
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 516, in <module>
    main()
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 385, in main
    rlhf_engine = DeepSpeedRLHFEngine(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 57, in __init__
    self.critic = self._init_critic(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 193, in _init_critic
    critic_model = create_critic_model(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/utils/model/model_utils.py", line 69, in create_critic_model
    critic_model.load_state_dict(
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RewardModel:
        Missing key(s) in state_dict: "rwtransformer.h.0.attn.bias", "rwtransformer.h.0.attn.masked_bias", "rwtransformer.h.1.attn.bias", "rwtransformer.h.1.attn.masked_bias", "rwtransformer.h.2.attn.bias", "rwtransformer.h.2.attn.masked_bias", "rwtransformer.h.3.attn.bias", "rwtransformer.h.3.attn.masked_bias", "rwtransformer.h.4.attn.bias", "rwtransformer.h.4.attn.masked_bias", "rwtransformer.h.5.attn.bias", "rwtransformer.h.5.attn.masked_bias", "rwtransformer.h.6.attn.bias", "rwtransformer.h.6.attn.masked_bias", "rwtransformer.h.7.attn.bias", "rwtransformer.h.7.attn.masked_bias", "rwtransformer.h.8.attn.bias", "rwtransformer.h.8.attn.masked_bias", "rwtransformer.h.9.attn.bias", "rwtransformer.h.9.attn.masked_bias", "rwtransformer.h.10.attn.bias", "rwtransformer.h.10.attn.masked_bias", "rwtransformer.h.11.attn.bias", "rwtransformer.h.11.attn.masked_bias". 
[2023-04-20 11:58:58,807] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 24065

These are the keys in model.named_parameters() before training step 2:

['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.2.mlp.c_fc.weight', 'h.2.mlp.c_fc.bias', 'h.2.mlp.c_proj.weight', 'h.2.mlp.c_proj.bias', 'h.3.ln_1.weight', 'h.3.ln_1.bias', 'h.3.attn.c_attn.weight', 'h.3.attn.c_attn.bias', 'h.3.attn.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.3.ln_2.weight', 'h.3.ln_2.bias', 'h.3.mlp.c_fc.weight', 'h.3.mlp.c_fc.bias', 'h.3.mlp.c_proj.weight', 'h.3.mlp.c_proj.bias', 'h.4.ln_1.weight', 'h.4.ln_1.bias', 'h.4.attn.c_attn.weight', 'h.4.attn.c_attn.bias', 'h.4.attn.c_proj.weight', 'h.4.attn.c_proj.bias', 'h.4.ln_2.weight', 'h.4.ln_2.bias', 'h.4.mlp.c_fc.weight', 'h.4.mlp.c_fc.bias', 'h.4.mlp.c_proj.weight', 'h.4.mlp.c_proj.bias', 'h.5.ln_1.weight', 'h.5.ln_1.bias', 'h.5.attn.c_attn.weight', 'h.5.attn.c_attn.bias', 'h.5.attn.c_proj.weight', 'h.5.attn.c_proj.bias', 'h.5.ln_2.weight', 'h.5.ln_2.bias', 'h.5.mlp.c_fc.weight', 'h.5.mlp.c_fc.bias', 'h.5.mlp.c_proj.weight', 'h.5.mlp.c_proj.bias', 'h.6.ln_1.weight', 'h.6.ln_1.bias', 'h.6.attn.c_attn.weight', 'h.6.attn.c_attn.bias', 'h.6.attn.c_proj.weight', 'h.6.attn.c_proj.bias', 'h.6.ln_2.weight', 'h.6.ln_2.bias', 'h.6.mlp.c_fc.weight', 'h.6.mlp.c_fc.bias', 'h.6.mlp.c_proj.weight', 'h.6.mlp.c_proj.bias', 'h.7.ln_1.weight', 'h.7.ln_1.bias', 'h.7.attn.c_attn.weight', 'h.7.attn.c_attn.bias', 'h.7.attn.c_proj.weight', 
'h.7.attn.c_proj.bias', 'h.7.ln_2.weight', 'h.7.ln_2.bias', 'h.7.mlp.c_fc.weight', 'h.7.mlp.c_fc.bias', 'h.7.mlp.c_proj.weight', 'h.7.mlp.c_proj.bias', 'h.8.ln_1.weight', 'h.8.ln_1.bias', 'h.8.attn.c_attn.weight', 'h.8.attn.c_attn.bias', 'h.8.attn.c_proj.weight', 'h.8.attn.c_proj.bias', 'h.8.ln_2.weight', 'h.8.ln_2.bias', 'h.8.mlp.c_fc.weight', 'h.8.mlp.c_fc.bias', 'h.8.mlp.c_proj.weight', 'h.8.mlp.c_proj.bias', 'h.9.ln_1.weight', 'h.9.ln_1.bias', 'h.9.attn.c_attn.weight', 'h.9.attn.c_attn.bias', 'h.9.attn.c_proj.weight', 'h.9.attn.c_proj.bias', 'h.9.ln_2.weight', 'h.9.ln_2.bias', 'h.9.mlp.c_fc.weight', 'h.9.mlp.c_fc.bias', 'h.9.mlp.c_proj.weight', 'h.9.mlp.c_proj.bias', 'h.10.ln_1.weight', 'h.10.ln_1.bias', 'h.10.attn.c_attn.weight', 'h.10.attn.c_attn.bias', 'h.10.attn.c_proj.weight', 'h.10.attn.c_proj.bias', 'h.10.ln_2.weight', 'h.10.ln_2.bias', 'h.10.mlp.c_fc.weight', 'h.10.mlp.c_fc.bias', 'h.10.mlp.c_proj.weight', 'h.10.mlp.c_proj.bias', 'h.11.ln_1.weight', 'h.11.ln_1.bias', 'h.11.attn.c_attn.weight', 'h.11.attn.c_attn.bias', 'h.11.attn.c_proj.weight', 'h.11.attn.c_proj.bias', 'h.11.ln_2.weight', 'h.11.ln_2.bias', 'h.11.mlp.c_fc.weight', 'h.11.mlp.c_fc.bias', 'h.11.mlp.c_proj.weight', 'h.11.mlp.c_proj.bias', 'ln_f.weight', 'ln_f.bias']

These are the keys of the trained RewardModel after loading it with torch.load():

['v_head.weight', 'rwtransformer.wte.weight', 'rwtransformer.wpe.weight', 'rwtransformer.h.0.ln_1.weight', 'rwtransformer.h.0.ln_1.bias', 'rwtransformer.h.0.attn.c_attn.weight', 'rwtransformer.h.0.attn.c_attn.bias', 'rwtransformer.h.0.attn.c_proj.weight', 'rwtransformer.h.0.attn.c_proj.bias', 'rwtransformer.h.0.ln_2.weight', 'rwtransformer.h.0.ln_2.bias', 'rwtransformer.h.0.mlp.c_fc.weight', 'rwtransformer.h.0.mlp.c_fc.bias', 'rwtransformer.h.0.mlp.c_proj.weight', 'rwtransformer.h.0.mlp.c_proj.bias', 'rwtransformer.h.1.ln_1.weight', 'rwtransformer.h.1.ln_1.bias', 'rwtransformer.h.1.attn.c_attn.weight', 'rwtransformer.h.1.attn.c_attn.bias', 'rwtransformer.h.1.attn.c_proj.weight', 'rwtransformer.h.1.attn.c_proj.bias', 'rwtransformer.h.1.ln_2.weight', 'rwtransformer.h.1.ln_2.bias', 'rwtransformer.h.1.mlp.c_fc.weight', 'rwtransformer.h.1.mlp.c_fc.bias', 'rwtransformer.h.1.mlp.c_proj.weight', 'rwtransformer.h.1.mlp.c_proj.bias', 'rwtransformer.h.2.ln_1.weight', 'rwtransformer.h.2.ln_1.bias', 'rwtransformer.h.2.attn.c_attn.weight', 'rwtransformer.h.2.attn.c_attn.bias', 'rwtransformer.h.2.attn.c_proj.weight', 'rwtransformer.h.2.attn.c_proj.bias', 'rwtransformer.h.2.ln_2.weight', 'rwtransformer.h.2.ln_2.bias', 'rwtransformer.h.2.mlp.c_fc.weight', 'rwtransformer.h.2.mlp.c_fc.bias', 'rwtransformer.h.2.mlp.c_proj.weight', 'rwtransformer.h.2.mlp.c_proj.bias', 'rwtransformer.h.3.ln_1.weight', 'rwtransformer.h.3.ln_1.bias', 'rwtransformer.h.3.attn.c_attn.weight', 'rwtransformer.h.3.attn.c_attn.bias', 'rwtransformer.h.3.attn.c_proj.weight', 'rwtransformer.h.3.attn.c_proj.bias', 'rwtransformer.h.3.ln_2.weight', 'rwtransformer.h.3.ln_2.bias', 'rwtransformer.h.3.mlp.c_fc.weight', 'rwtransformer.h.3.mlp.c_fc.bias', 'rwtransformer.h.3.mlp.c_proj.weight', 'rwtransformer.h.3.mlp.c_proj.bias', 'rwtransformer.h.4.ln_1.weight', 'rwtransformer.h.4.ln_1.bias', 'rwtransformer.h.4.attn.c_attn.weight', 'rwtransformer.h.4.attn.c_attn.bias', 'rwtransformer.h.4.attn.c_proj.weight', 
'rwtransformer.h.4.attn.c_proj.bias', 'rwtransformer.h.4.ln_2.weight', 'rwtransformer.h.4.ln_2.bias', 'rwtransformer.h.4.mlp.c_fc.weight', 'rwtransformer.h.4.mlp.c_fc.bias', 'rwtransformer.h.4.mlp.c_proj.weight', 'rwtransformer.h.4.mlp.c_proj.bias', 'rwtransformer.h.5.ln_1.weight', 'rwtransformer.h.5.ln_1.bias', 'rwtransformer.h.5.attn.c_attn.weight', 'rwtransformer.h.5.attn.c_attn.bias', 'rwtransformer.h.5.attn.c_proj.weight', 'rwtransformer.h.5.attn.c_proj.bias', 'rwtransformer.h.5.ln_2.weight', 'rwtransformer.h.5.ln_2.bias', 'rwtransformer.h.5.mlp.c_fc.weight', 'rwtransformer.h.5.mlp.c_fc.bias', 'rwtransformer.h.5.mlp.c_proj.weight', 'rwtransformer.h.5.mlp.c_proj.bias', 'rwtransformer.h.6.ln_1.weight', 'rwtransformer.h.6.ln_1.bias', 'rwtransformer.h.6.attn.c_attn.weight', 'rwtransformer.h.6.attn.c_attn.bias', 'rwtransformer.h.6.attn.c_proj.weight', 'rwtransformer.h.6.attn.c_proj.bias', 'rwtransformer.h.6.ln_2.weight', 'rwtransformer.h.6.ln_2.bias', 'rwtransformer.h.6.mlp.c_fc.weight', 'rwtransformer.h.6.mlp.c_fc.bias', 'rwtransformer.h.6.mlp.c_proj.weight', 'rwtransformer.h.6.mlp.c_proj.bias', 'rwtransformer.h.7.ln_1.weight', 'rwtransformer.h.7.ln_1.bias', 'rwtransformer.h.7.attn.c_attn.weight', 'rwtransformer.h.7.attn.c_attn.bias', 'rwtransformer.h.7.attn.c_proj.weight', 'rwtransformer.h.7.attn.c_proj.bias', 'rwtransformer.h.7.ln_2.weight', 'rwtransformer.h.7.ln_2.bias', 'rwtransformer.h.7.mlp.c_fc.weight', 'rwtransformer.h.7.mlp.c_fc.bias', 'rwtransformer.h.7.mlp.c_proj.weight', 'rwtransformer.h.7.mlp.c_proj.bias', 'rwtransformer.h.8.ln_1.weight', 'rwtransformer.h.8.ln_1.bias', 'rwtransformer.h.8.attn.c_attn.weight', 'rwtransformer.h.8.attn.c_attn.bias', 'rwtransformer.h.8.attn.c_proj.weight', 'rwtransformer.h.8.attn.c_proj.bias', 'rwtransformer.h.8.ln_2.weight', 'rwtransformer.h.8.ln_2.bias', 'rwtransformer.h.8.mlp.c_fc.weight', 'rwtransformer.h.8.mlp.c_fc.bias', 'rwtransformer.h.8.mlp.c_proj.weight', 'rwtransformer.h.8.mlp.c_proj.bias', 
'rwtransformer.h.9.ln_1.weight', 'rwtransformer.h.9.ln_1.bias', 'rwtransformer.h.9.attn.c_attn.weight', 'rwtransformer.h.9.attn.c_attn.bias', 'rwtransformer.h.9.attn.c_proj.weight', 'rwtransformer.h.9.attn.c_proj.bias', 'rwtransformer.h.9.ln_2.weight', 'rwtransformer.h.9.ln_2.bias', 'rwtransformer.h.9.mlp.c_fc.weight', 'rwtransformer.h.9.mlp.c_fc.bias', 'rwtransformer.h.9.mlp.c_proj.weight', 'rwtransformer.h.9.mlp.c_proj.bias', 'rwtransformer.h.10.ln_1.weight', 'rwtransformer.h.10.ln_1.bias', 'rwtransformer.h.10.attn.c_attn.weight', 'rwtransformer.h.10.attn.c_attn.bias', 'rwtransformer.h.10.attn.c_proj.weight', 'rwtransformer.h.10.attn.c_proj.bias', 'rwtransformer.h.10.ln_2.weight', 'rwtransformer.h.10.ln_2.bias', 'rwtransformer.h.10.mlp.c_fc.weight', 'rwtransformer.h.10.mlp.c_fc.bias', 'rwtransformer.h.10.mlp.c_proj.weight', 'rwtransformer.h.10.mlp.c_proj.bias', 'rwtransformer.h.11.ln_1.weight', 'rwtransformer.h.11.ln_1.bias', 'rwtransformer.h.11.attn.c_attn.weight', 'rwtransformer.h.11.attn.c_attn.bias', 'rwtransformer.h.11.attn.c_proj.weight', 'rwtransformer.h.11.attn.c_proj.bias', 'rwtransformer.h.11.ln_2.weight', 'rwtransformer.h.11.ln_2.bias', 'rwtransformer.h.11.mlp.c_fc.weight', 'rwtransformer.h.11.mlp.c_fc.bias', 'rwtransformer.h.11.mlp.c_proj.weight', 'rwtransformer.h.11.mlp.c_proj.bias', 'rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']

As you can see, the difference between the model before and after reward model training (after removing the 'rwtransformer' prefix) is this:

list(set(model2_params) - set(model_params))
['v_head.weight']

It looks like the bias is inserted into the state dict somewhere (I am not using LoRA for this one).
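A possible explanation, offered as a sketch rather than a confirmed diagnosis: as far as I can tell, `attn.bias` and `attn.masked_bias` in the Hugging Face GPT-2 attention block are registered buffers (the causal mask and its fill value), not trainable parameters. `named_parameters()` never lists buffers, so comparing parameter names cannot reveal them, while `state_dict()` and `load_state_dict()` do include persistent buffers. The toy module below is hypothetical (`ToyAttention` is not the real GPT-2 code) and only illustrates that difference:

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention block; not the GPT-2 code."""

    def __init__(self, max_positions: int = 4):
        super().__init__()
        # Trainable weights: listed by named_parameters() and state_dict().
        self.c_attn = nn.Linear(8, 8)
        # Causal mask stored as a buffer: part of state_dict(), but never
        # listed by named_parameters() because it is not trainable.
        mask = torch.tril(
            torch.ones(max_positions, max_positions, dtype=torch.bool))
        self.register_buffer("bias", mask)


attn = ToyAttention()
param_keys = {name for name, _ in attn.named_parameters()}
buffer_keys = {name for name, _ in attn.named_buffers()}
state_keys = set(attn.state_dict().keys())

print(sorted(param_keys))               # trainable parameters only
print(sorted(buffer_keys))              # buffers only
print(sorted(state_keys - param_keys))  # keys that only state_dict() sees
```

If the step 2 checkpoint was written from parameters only, or by a library version that treats these buffers differently, a strict load would report exactly the kind of missing keys shown in the traceback. That part is an assumption and has not been verified against the training scripts.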

EikeKohl avatar Apr 20 '23 12:04 EikeKohl

Adding the strict=False parameter in line 70 of utils.model.model_utils.create_critic_model lets me load the checkpoint:

def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False):
    # OPT model family always put a padding token at the beginning of the sequence,
    # we did not see this in other models but not sure if it is a general rule
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training)
    critic_model = RewardModel(
        critic_model,
        tokenizer,
        num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # critic model needs to load the weight here
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(
            model_ckpt_path
        ), f"Cannot find model checkpoint at {model_ckpt_path}"
        critic_model.load_state_dict(
            torch.load(model_ckpt_path, map_location='cpu'), strict=False)

    return critic_model

I don't know what causes the issue or what the side effects are 🤔
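One way to limit the side effects of `strict=False` is to inspect what was skipped instead of ignoring it. `load_state_dict` returns an object with `missing_keys` and `unexpected_keys`; the helper below is a minimal sketch that accepts the load only when the missing keys are the attention mask buffers named in the error message. The name `check_load_result` and the allowed pattern are assumptions made for illustration, not part of DeepSpeed-Chat:

```python
import re

# Assumption: the only keys that may safely be absent are the GPT-2 style
# attention mask buffers listed in the traceback above.
ALLOWED_MISSING = re.compile(
    r"^rwtransformer\.h\.\d+\.attn\.(bias|masked_bias)$")


def check_load_result(missing_keys, unexpected_keys):
    """Raise if a non-strict load skipped anything other than mask buffers.

    Returns the number of (allowed) missing keys otherwise.
    """
    bad_missing = [k for k in missing_keys if not ALLOWED_MISSING.match(k)]
    if bad_missing or unexpected_keys:
        raise RuntimeError(
            f"unexpected load result: missing={bad_missing}, "
            f"unexpected={list(unexpected_keys)}")
    return len(missing_keys)


# Example with key names taken from the error message.
ok = check_load_result(
    ["rwtransformer.h.0.attn.bias", "rwtransformer.h.0.attn.masked_bias"],
    [])
print(ok)
```

In `create_critic_model` this could be used as `result = critic_model.load_state_dict(..., strict=False)` followed by `check_load_result(result.missing_keys, result.unexpected_keys)`, so that a genuinely missing weight such as `v_head.weight` still fails loudly.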

EikeKohl avatar Apr 20 '23 13:04 EikeKohl

That dropped the error for me too! I am trying to load from OPT and GPT Neo checkpoints and had the same issue. Although, judging by the PyTorch docs, that could have some side effects...

alexf-a avatar May 30 '23 22:05 alexf-a

It really works for me! THANKS A LOT!

sunyuhan19981208 avatar Jun 15 '23 06:06 sunyuhan19981208