
Cannot load the previous model weights when using ZeRO 3 optimizer in DeepSpeed Chat

caoyu-noob opened this issue 1 year ago · 4 comments

Problem:

I have a previously trained model state dict file, e.g., a reward model saved as PATH/pytorch_model.bin. When I try to reload it for further training with the ZeRO-3 optimizer, an error occurs at L72 in DeepSpeed-Chat/training/utils/model/model_utils.py.

The exception message is:

size mismatch for rwtranrsformer.h.0.mlp.dense_4h_to_h.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([0]).

Possible Reason:

When the ZeRO-3 optimizer is used, an HfDeepSpeedConfig is created at L30 in DeepSpeed-Chat/training/utils/model/model_utils.py. Any model built afterwards is initialized with its parameters automatically partitioned across GPUs by HF/DeepSpeed, so its weights cannot be loaded directly with PyTorch's load_state_dict (a minimal repro is sketched below).

caoyu-noob avatar Apr 25 '23 03:04 caoyu-noob
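For what it's worth, the torch.Size([0]) shapes in the error above are expected under ZeRO-3: while an HfDeepSpeedConfig object is alive, transformers constructs models inside deepspeed.zero.Init, which immediately replaces each parameter with an empty placeholder and partitions the real data. A minimal sketch of this behavior, assuming a minimal hand-written DeepSpeed config and a placeholder model name:

```python
# Sketch only: the config below is an assumed minimal ZeRO-3 config, and
# "facebook/opt-1.3b" is a placeholder model. Launch with the deepspeed
# launcher (e.g. `deepspeed --num_gpus 1 repro.py`) so that
# deepspeed.zero.Init can set up torch.distributed.
from transformers import AutoModel
from transformers.deepspeed import HfDeepSpeedConfig  # transformers.integrations in newer versions

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
}
dschf = HfDeepSpeedConfig(ds_config)  # must stay alive while the model is built

# Because HfDeepSpeedConfig is alive, from_pretrained runs under
# deepspeed.zero.Init and partitions every parameter at construction time.
model = AutoModel.from_pretrained("facebook/opt-1.3b")

p = next(model.parameters())
print(p.shape)     # torch.Size([0]) -- only a placeholder on each rank
print(p.ds_shape)  # the true shape, kept as DeepSpeed metadata
```

This is why a plain model.load_state_dict(torch.load(...)) sees torch.Size([0]) in the current model and raises the size-mismatch error.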

@caoyu-noob, you can use the zero_to_fp32.py script to convert the zero3 checkpoints into a regular pytorch checkpoint. You can find documentation of this script and other checkpoint conversion options here.

tjruwase avatar Apr 26 '23 04:04 tjruwase
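For readers landing here: zero_to_fp32.py operates on a checkpoint directory saved via engine.save_checkpoint(), and DeepSpeed copies the script into that directory for you. A hedged sketch of both the CLI and the programmatic route (paths are placeholders, and `model` is assumed to be a plain, non-partitioned instance of the same architecture):

```python
# CLI route (the script is placed in the checkpoint directory by DeepSpeed):
#   cd PATH/TO/CHECKPOINT_DIR
#   python zero_to_fp32.py . pytorch_model.bin
#
# Programmatic route using the same helpers:
from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    load_state_dict_from_zero_checkpoint,
)

# Consolidate the sharded ZeRO-3 checkpoint into a single fp32 state dict ...
state_dict = get_fp32_state_dict_from_zero_checkpoint("PATH/TO/CHECKPOINT_DIR")

# ... or load the consolidated weights straight into an unwrapped module.
model = load_state_dict_from_zero_checkpoint(model, "PATH/TO/CHECKPOINT_DIR")
```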

How can this problem be solved?

Pattaro avatar Apr 26 '23 10:04 Pattaro

> @caoyu-noob, you can use the zero_to_fp32.py script to convert the zero3 checkpoints into a regular pytorch checkpoint. You can find documentation of this script and other checkpoint conversion options here.

I think this issue is about the opposite direction: how can we load_state_dict a regular PyTorch checkpoint into a ZeRO-3 model, rather than converting a ZeRO-3 checkpoint into a regular one?

I'm still facing similar issues when loading the weights of some layers with load_state_dict before starting ZeRO-3 training. Any help or guidance on this scenario? Thank you!

XenonLamb avatar Mar 20 '24 00:03 XenonLamb
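One pattern that addresses the direction asked about here (loading a regular pytorch_model.bin into an already ZeRO-3-partitioned model) is deepspeed.zero.GatheredParameters, which temporarily materializes the full parameters so load_state_dict sees real shapes. A sketch, assuming `model` is the partitioned HF model and reusing PATH/pytorch_model.bin from the original report:

```python
import torch
import deepspeed

# `model` is assumed to be the already-constructed, ZeRO-3 partitioned module.
state_dict = torch.load("PATH/pytorch_model.bin", map_location="cpu")

# Temporarily gather the full parameters; modifier_rank=0 means only rank 0
# writes, and DeepSpeed re-partitions (broadcasting the update) on exit.
with deepspeed.zero.GatheredParameters(list(model.parameters()),
                                       modifier_rank=0):
    if torch.distributed.get_rank() == 0:
        # strict=False tolerates loading only a subset of layers, as in the
        # partial-loading scenario described above.
        model.load_state_dict(state_dict, strict=False)
```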


I'm facing the same issue. Is there any way to solve it?

Uxito-Ada avatar May 23 '24 06:05 Uxito-Ada