DeepSpeedExamples
Cannot load the previous model weights when using ZeRO 3 optimizer in DeepSpeed Chat
Problem:
I have a previously-trained model state dict file, e.g., a reward model saved as PATH/pytorch_model.bin. When I try to reload it for further training with the ZeRO-3 optimizer, an error occurs at L72 in DeepSpeed-Chat/training/utils/model/model_utils.py.
The exception looks like this:
size mismatch for rwtranrsformer.h.0.mlp.dense_4h_to_h.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([0]).
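For context, a quick way to see what ZeRO-3 has done to the parameters; a minimal sketch, where `model` stands for the partitioned model built in model_utils.py and `ds_shape` is the attribute DeepSpeed attaches to remember the original shape:
```python
def show_partitioned_shapes(model):
    """Print each parameter's local (partitioned) shape vs. its original shape.

    Under ZeRO-3 the local tensor is an empty placeholder (torch.Size([0]));
    DeepSpeed keeps the original shape in the `ds_shape` attribute it attaches
    during partitioning.
    """
    for name, param in model.named_parameters():
        print(name, tuple(param.shape), getattr(param, "ds_shape", None))
        # e.g. rwtranrsformer.h.0.mlp.dense_4h_to_h.bias (0,) torch.Size([1024])
```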
Possible Reason:
When using the ZeRO-3 optimizer, an HfDeepSpeedConfig is created at L30 in DeepSpeed-Chat/training/utils/model/model_utils.py. Any model created afterwards is then initialized and partitioned across GPUs automatically by Hugging Face, so its weights cannot be loaded directly via PyTorch's load_state_dict.
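A minimal sketch of that failure mode, assuming a launched distributed run and a stage-3 config; the model name and config contents below are placeholders, not the repo's actual values:
```python
import torch
from transformers import AutoModel
# In transformers releases contemporary with DeepSpeed-Chat this lived in
# transformers.deepspeed; newer versions move it to transformers.integrations.
from transformers.deepspeed import HfDeepSpeedConfig

ds_config = {"zero_optimization": {"stage": 3}, "train_batch_size": 8}  # placeholder stage-3 config
dschf = HfDeepSpeedConfig(ds_config)  # must be created (and kept alive) before from_pretrained

# Because a stage-3 config is registered, from_pretrained initializes the model
# under deepspeed.zero.Init and shards its parameters across the GPUs.
model = AutoModel.from_pretrained("facebook/opt-350m")  # placeholder model

# Each rank now only holds empty per-parameter placeholders, so loading a full
# state dict fails with the size-mismatch error shown above.
model.load_state_dict(torch.load("PATH/pytorch_model.bin"))  # raises RuntimeError
```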
How can I solve this problem?
@caoyu-noob, you can use the zero_to_fp32.py script to convert the ZeRO-3 checkpoints into a regular PyTorch checkpoint. You can find documentation of this script and other checkpoint conversion options here.
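For reference, DeepSpeed also exposes the same conversion as a Python helper; a hedged sketch, where the checkpoint directory is a placeholder for the folder written by engine.save_checkpoint (which is also where zero_to_fp32.py is saved):
```python
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Consolidate the ZeRO-3 shards into a single fp32 state dict on CPU, then save
# it as a regular, unpartitioned PyTorch checkpoint.
state_dict = get_fp32_state_dict_from_zero_checkpoint("path/to/checkpoint_dir")
torch.save(state_dict, "pytorch_model.bin")
```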
I think this issue is about doing the opposite: how can we load_state_dict a regular PyTorch checkpoint into a ZeRO-3 model, rather than the other way around?
I'm still facing similar issues when loading weights of some layers with load_state_dict before doing zero-3 training. Any help or guidance on this scenario? Thank you!
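One workaround for this direction (loading a regular checkpoint into an already-partitioned ZeRO-3 model) is to gather the parameters temporarily with deepspeed.zero.GatheredParameters; a minimal sketch, where `model` and the checkpoint path are placeholders:
```python
import torch
import torch.distributed as dist
import deepspeed

def load_full_state_dict_into_zero3_model(model, ckpt_path="PATH/pytorch_model.bin"):
    """Copy a regular (unpartitioned) state dict into a ZeRO-3 partitioned model."""
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Gather the full parameters, modify them on rank 0, and let DeepSpeed
    # broadcast the update and re-partition when the context exits.
    with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
        if dist.get_rank() == 0:
            model.load_state_dict(state_dict, strict=False)  # strict=False if only some layers match
```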
Facing the same issue, any way to solve it?