DeepSpeedExamples
Missing key(s) in state_dict for bias in attention blocks
I am trying to run step 3 of the RLHF examples using a RewardModel checkpoint that I trained in step 2 of the examples. For every step I used the provided .sh scripts and only adjusted the model and data paths. Unfortunately, I encountered the following exception:
*******************[end] Initialized Ref Model [end] (duration: 0.59s)********************
************************[start] Initializing Critic Model [start] ************************
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 516, in <module>
    main()
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 385, in main
    rlhf_engine = DeepSpeedRLHFEngine(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 57, in __init__
    self.critic = self._init_critic(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 193, in _init_critic
    critic_model = create_critic_model(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/utils/model/model_utils.py", line 69, in create_critic_model
    critic_model.load_state_dict(
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RewardModel:
Missing key(s) in state_dict: "rwtransformer.h.0.attn.bias", "rwtransformer.h.0.attn.masked_bias", "rwtransformer.h.1.attn.bias", "rwtransformer.h.1.attn.masked_bias", "rwtransformer.h.2.attn.bias", "rwtransformer.h.2.attn.masked_bias", "rwtransformer.h.3.attn.bias", "rwtransformer.h.3.attn.masked_bias", "rwtransformer.h.4.attn.bias", "rwtransformer.h.4.attn.masked_bias", "rwtransformer.h.5.attn.bias", "rwtransformer.h.5.attn.masked_bias", "rwtransformer.h.6.attn.bias", "rwtransformer.h.6.attn.masked_bias", "rwtransformer.h.7.attn.bias", "rwtransformer.h.7.attn.masked_bias", "rwtransformer.h.8.attn.bias", "rwtransformer.h.8.attn.masked_bias", "rwtransformer.h.9.attn.bias", "rwtransformer.h.9.attn.masked_bias", "rwtransformer.h.10.attn.bias", "rwtransformer.h.10.attn.masked_bias", "rwtransformer.h.11.attn.bias", "rwtransformer.h.11.attn.masked_bias".
[2023-04-20 11:58:58,807] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 24065
These are the keys in model.named_parameters() before training step 2:
['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.2.mlp.c_fc.weight', 'h.2.mlp.c_fc.bias', 'h.2.mlp.c_proj.weight', 'h.2.mlp.c_proj.bias', 'h.3.ln_1.weight', 'h.3.ln_1.bias', 'h.3.attn.c_attn.weight', 'h.3.attn.c_attn.bias', 'h.3.attn.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.3.ln_2.weight', 'h.3.ln_2.bias', 'h.3.mlp.c_fc.weight', 'h.3.mlp.c_fc.bias', 'h.3.mlp.c_proj.weight', 'h.3.mlp.c_proj.bias', 'h.4.ln_1.weight', 'h.4.ln_1.bias', 'h.4.attn.c_attn.weight', 'h.4.attn.c_attn.bias', 'h.4.attn.c_proj.weight', 'h.4.attn.c_proj.bias', 'h.4.ln_2.weight', 'h.4.ln_2.bias', 'h.4.mlp.c_fc.weight', 'h.4.mlp.c_fc.bias', 'h.4.mlp.c_proj.weight', 'h.4.mlp.c_proj.bias', 'h.5.ln_1.weight', 'h.5.ln_1.bias', 'h.5.attn.c_attn.weight', 'h.5.attn.c_attn.bias', 'h.5.attn.c_proj.weight', 'h.5.attn.c_proj.bias', 'h.5.ln_2.weight', 'h.5.ln_2.bias', 'h.5.mlp.c_fc.weight', 'h.5.mlp.c_fc.bias', 'h.5.mlp.c_proj.weight', 'h.5.mlp.c_proj.bias', 'h.6.ln_1.weight', 'h.6.ln_1.bias', 'h.6.attn.c_attn.weight', 'h.6.attn.c_attn.bias', 'h.6.attn.c_proj.weight', 'h.6.attn.c_proj.bias', 'h.6.ln_2.weight', 'h.6.ln_2.bias', 'h.6.mlp.c_fc.weight', 'h.6.mlp.c_fc.bias', 'h.6.mlp.c_proj.weight', 'h.6.mlp.c_proj.bias', 'h.7.ln_1.weight', 'h.7.ln_1.bias', 'h.7.attn.c_attn.weight', 'h.7.attn.c_attn.bias', 'h.7.attn.c_proj.weight', 'h.7.attn.c_proj.bias', 'h.7.ln_2.weight', 'h.7.ln_2.bias', 'h.7.mlp.c_fc.weight', 'h.7.mlp.c_fc.bias', 'h.7.mlp.c_proj.weight', 'h.7.mlp.c_proj.bias', 'h.8.ln_1.weight', 'h.8.ln_1.bias', 'h.8.attn.c_attn.weight', 'h.8.attn.c_attn.bias', 'h.8.attn.c_proj.weight', 'h.8.attn.c_proj.bias', 'h.8.ln_2.weight', 'h.8.ln_2.bias', 'h.8.mlp.c_fc.weight', 'h.8.mlp.c_fc.bias', 'h.8.mlp.c_proj.weight', 'h.8.mlp.c_proj.bias', 'h.9.ln_1.weight', 'h.9.ln_1.bias', 'h.9.attn.c_attn.weight', 'h.9.attn.c_attn.bias', 'h.9.attn.c_proj.weight', 'h.9.attn.c_proj.bias', 'h.9.ln_2.weight', 'h.9.ln_2.bias', 'h.9.mlp.c_fc.weight', 'h.9.mlp.c_fc.bias', 'h.9.mlp.c_proj.weight', 'h.9.mlp.c_proj.bias', 'h.10.ln_1.weight', 'h.10.ln_1.bias', 'h.10.attn.c_attn.weight', 'h.10.attn.c_attn.bias', 'h.10.attn.c_proj.weight', 'h.10.attn.c_proj.bias', 'h.10.ln_2.weight', 'h.10.ln_2.bias', 'h.10.mlp.c_fc.weight', 'h.10.mlp.c_fc.bias', 'h.10.mlp.c_proj.weight', 'h.10.mlp.c_proj.bias', 'h.11.ln_1.weight', 'h.11.ln_1.bias', 'h.11.attn.c_attn.weight', 'h.11.attn.c_attn.bias', 'h.11.attn.c_proj.weight', 'h.11.attn.c_proj.bias', 'h.11.ln_2.weight', 'h.11.ln_2.bias', 'h.11.mlp.c_fc.weight', 'h.11.mlp.c_fc.bias', 'h.11.mlp.c_proj.weight', 'h.11.mlp.c_proj.bias', 'ln_f.weight', 'ln_f.bias']
These are the keys of the trained RewardModel after loading it with torch.load():
['v_head.weight', 'rwtransformer.wte.weight', 'rwtransformer.wpe.weight', 'rwtransformer.h.0.ln_1.weight', 'rwtransformer.h.0.ln_1.bias', 'rwtransformer.h.0.attn.c_attn.weight', 'rwtransformer.h.0.attn.c_attn.bias', 'rwtransformer.h.0.attn.c_proj.weight', 'rwtransformer.h.0.attn.c_proj.bias', 'rwtransformer.h.0.ln_2.weight', 'rwtransformer.h.0.ln_2.bias', 'rwtransformer.h.0.mlp.c_fc.weight', 'rwtransformer.h.0.mlp.c_fc.bias', 'rwtransformer.h.0.mlp.c_proj.weight', 'rwtransformer.h.0.mlp.c_proj.bias', 'rwtransformer.h.1.ln_1.weight', 'rwtransformer.h.1.ln_1.bias', 'rwtransformer.h.1.attn.c_attn.weight', 'rwtransformer.h.1.attn.c_attn.bias', 'rwtransformer.h.1.attn.c_proj.weight', 'rwtransformer.h.1.attn.c_proj.bias', 'rwtransformer.h.1.ln_2.weight', 'rwtransformer.h.1.ln_2.bias', 'rwtransformer.h.1.mlp.c_fc.weight', 'rwtransformer.h.1.mlp.c_fc.bias', 'rwtransformer.h.1.mlp.c_proj.weight', 'rwtransformer.h.1.mlp.c_proj.bias', 'rwtransformer.h.2.ln_1.weight', 'rwtransformer.h.2.ln_1.bias', 'rwtransformer.h.2.attn.c_attn.weight', 'rwtransformer.h.2.attn.c_attn.bias', 'rwtransformer.h.2.attn.c_proj.weight', 'rwtransformer.h.2.attn.c_proj.bias', 'rwtransformer.h.2.ln_2.weight', 'rwtransformer.h.2.ln_2.bias', 'rwtransformer.h.2.mlp.c_fc.weight', 'rwtransformer.h.2.mlp.c_fc.bias', 'rwtransformer.h.2.mlp.c_proj.weight', 'rwtransformer.h.2.mlp.c_proj.bias', 'rwtransformer.h.3.ln_1.weight', 'rwtransformer.h.3.ln_1.bias', 'rwtransformer.h.3.attn.c_attn.weight', 'rwtransformer.h.3.attn.c_attn.bias', 'rwtransformer.h.3.attn.c_proj.weight', 'rwtransformer.h.3.attn.c_proj.bias', 'rwtransformer.h.3.ln_2.weight', 'rwtransformer.h.3.ln_2.bias', 'rwtransformer.h.3.mlp.c_fc.weight', 'rwtransformer.h.3.mlp.c_fc.bias', 'rwtransformer.h.3.mlp.c_proj.weight', 'rwtransformer.h.3.mlp.c_proj.bias', 'rwtransformer.h.4.ln_1.weight', 'rwtransformer.h.4.ln_1.bias', 'rwtransformer.h.4.attn.c_attn.weight', 'rwtransformer.h.4.attn.c_attn.bias', 'rwtransformer.h.4.attn.c_proj.weight', 'rwtransformer.h.4.attn.c_proj.bias', 'rwtransformer.h.4.ln_2.weight', 'rwtransformer.h.4.ln_2.bias', 'rwtransformer.h.4.mlp.c_fc.weight', 'rwtransformer.h.4.mlp.c_fc.bias', 'rwtransformer.h.4.mlp.c_proj.weight', 'rwtransformer.h.4.mlp.c_proj.bias', 'rwtransformer.h.5.ln_1.weight', 'rwtransformer.h.5.ln_1.bias', 'rwtransformer.h.5.attn.c_attn.weight', 'rwtransformer.h.5.attn.c_attn.bias', 'rwtransformer.h.5.attn.c_proj.weight', 'rwtransformer.h.5.attn.c_proj.bias', 'rwtransformer.h.5.ln_2.weight', 'rwtransformer.h.5.ln_2.bias', 'rwtransformer.h.5.mlp.c_fc.weight', 'rwtransformer.h.5.mlp.c_fc.bias', 'rwtransformer.h.5.mlp.c_proj.weight', 'rwtransformer.h.5.mlp.c_proj.bias', 'rwtransformer.h.6.ln_1.weight', 'rwtransformer.h.6.ln_1.bias', 'rwtransformer.h.6.attn.c_attn.weight', 'rwtransformer.h.6.attn.c_attn.bias', 'rwtransformer.h.6.attn.c_proj.weight', 'rwtransformer.h.6.attn.c_proj.bias', 'rwtransformer.h.6.ln_2.weight', 'rwtransformer.h.6.ln_2.bias', 'rwtransformer.h.6.mlp.c_fc.weight', 'rwtransformer.h.6.mlp.c_fc.bias', 'rwtransformer.h.6.mlp.c_proj.weight', 'rwtransformer.h.6.mlp.c_proj.bias', 'rwtransformer.h.7.ln_1.weight', 'rwtransformer.h.7.ln_1.bias', 'rwtransformer.h.7.attn.c_attn.weight', 'rwtransformer.h.7.attn.c_attn.bias', 'rwtransformer.h.7.attn.c_proj.weight', 'rwtransformer.h.7.attn.c_proj.bias', 'rwtransformer.h.7.ln_2.weight', 'rwtransformer.h.7.ln_2.bias', 'rwtransformer.h.7.mlp.c_fc.weight', 'rwtransformer.h.7.mlp.c_fc.bias', 'rwtransformer.h.7.mlp.c_proj.weight', 'rwtransformer.h.7.mlp.c_proj.bias', 
'rwtransformer.h.8.ln_1.weight', 'rwtransformer.h.8.ln_1.bias', 'rwtransformer.h.8.attn.c_attn.weight', 'rwtransformer.h.8.attn.c_attn.bias', 'rwtransformer.h.8.attn.c_proj.weight', 'rwtransformer.h.8.attn.c_proj.bias', 'rwtransformer.h.8.ln_2.weight', 'rwtransformer.h.8.ln_2.bias', 'rwtransformer.h.8.mlp.c_fc.weight', 'rwtransformer.h.8.mlp.c_fc.bias', 'rwtransformer.h.8.mlp.c_proj.weight', 'rwtransformer.h.8.mlp.c_proj.bias', 'rwtransformer.h.9.ln_1.weight', 'rwtransformer.h.9.ln_1.bias', 'rwtransformer.h.9.attn.c_attn.weight', 'rwtransformer.h.9.attn.c_attn.bias', 'rwtransformer.h.9.attn.c_proj.weight', 'rwtransformer.h.9.attn.c_proj.bias', 'rwtransformer.h.9.ln_2.weight', 'rwtransformer.h.9.ln_2.bias', 'rwtransformer.h.9.mlp.c_fc.weight', 'rwtransformer.h.9.mlp.c_fc.bias', 'rwtransformer.h.9.mlp.c_proj.weight', 'rwtransformer.h.9.mlp.c_proj.bias', 'rwtransformer.h.10.ln_1.weight', 'rwtransformer.h.10.ln_1.bias', 'rwtransformer.h.10.attn.c_attn.weight', 'rwtransformer.h.10.attn.c_attn.bias', 'rwtransformer.h.10.attn.c_proj.weight', 'rwtransformer.h.10.attn.c_proj.bias', 'rwtransformer.h.10.ln_2.weight', 'rwtransformer.h.10.ln_2.bias', 'rwtransformer.h.10.mlp.c_fc.weight', 'rwtransformer.h.10.mlp.c_fc.bias', 'rwtransformer.h.10.mlp.c_proj.weight', 'rwtransformer.h.10.mlp.c_proj.bias', 'rwtransformer.h.11.ln_1.weight', 'rwtransformer.h.11.ln_1.bias', 'rwtransformer.h.11.attn.c_attn.weight', 'rwtransformer.h.11.attn.c_attn.bias', 'rwtransformer.h.11.attn.c_proj.weight', 'rwtransformer.h.11.attn.c_proj.bias', 'rwtransformer.h.11.ln_2.weight', 'rwtransformer.h.11.ln_2.bias', 'rwtransformer.h.11.mlp.c_fc.weight', 'rwtransformer.h.11.mlp.c_fc.bias', 'rwtransformer.h.11.mlp.c_proj.weight', 'rwtransformer.h.11.mlp.c_proj.bias', 'rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']
As you can see, the difference between the model before and after reward model training (after removing the 'rwtransformer' prefix) is this:
list(set(model2_params) - set(model_params))
['v_head.weight']
It looks like the attn.bias / attn.masked_bias entries are expected somewhere in the model's state_dict but are not present in the saved checkpoint (I am not using LoRA for this run).
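For reference, here is a minimal diagnostic sketch (not part of the repo; the gpt2 base model and the step-2 output path are placeholders to adjust) showing that the attn.bias / attn.masked_bias entries are registered buffers rather than learned parameters, and checking which keys the step-2 checkpoint actually saved:

import torch
from transformers import AutoModel

# Same base architecture that the reward model wraps (placeholder: use your own base model).
base = AutoModel.from_pretrained("gpt2")

param_keys = {name for name, _ in base.named_parameters()}
state_keys = set(base.state_dict().keys())

# attn.bias / attn.masked_bias appear in state_dict() but not in named_parameters():
# they are constant causal-mask buffers created at __init__ time (whether they are
# saved as persistent buffers depends on the transformers version).
print(sorted(state_keys - param_keys))

# Compare against what step 2 actually saved (placeholder path).
ckpt = torch.load("path/to/step2_output/pytorch_model.bin", map_location="cpu")
ckpt_keys = {k.replace("rwtransformer.", "", 1) for k in ckpt.keys()}
print(sorted(state_keys - ckpt_keys))  # keys the checkpoint does not contain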
Adding strict=False to the load_state_dict call in line 70 of utils.model.model_utils.create_critic_model lets me load the checkpoint:
def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False):
    # The OPT model family always puts a padding token at the beginning of the sequence;
    # we did not see this in other models, but we are not sure if it is a general rule.
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training)
    critic_model = RewardModel(
        critic_model,
        tokenizer,
        num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        # The critic model needs to load the step-2 reward model weights here.
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(
            model_ckpt_path
        ), f"Cannot find model checkpoint at {model_ckpt_path}"
        critic_model.load_state_dict(
            torch.load(model_ckpt_path, map_location='cpu'), strict=False)

    return critic_model
I don't know what causes the issue or what the side effects are 🤔
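If the side effects are a concern, one option (just a sketch of an extra safety check, not something the repository does) is to inspect the return value of load_state_dict and assert that the only keys skipped are the constant attention-mask buffers:

load_result = critic_model.load_state_dict(
    torch.load(model_ckpt_path, map_location='cpu'), strict=False)

# load_state_dict returns (missing_keys, unexpected_keys); anything other than the
# causal-mask buffers being skipped would point to a real checkpoint mismatch.
real_missing = [
    k for k in load_result.missing_keys
    if not k.endswith(('.attn.bias', '.attn.masked_bias'))
]
assert not real_missing and not load_result.unexpected_keys, (
    f"missing={real_missing}, unexpected={load_result.unexpected_keys}")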
That dropped the error for me too! I am trying to load from OPT and GPT-Neo checkpoints and hit the same issue. Although, judging by the PyTorch docs, strict=False could have some side effects...
It really works for me! THANKS A LOT!