With deepspeed zero3 enabled, loading from_pretrained() and resize_token_embeddings() do not work correctly
System Info
- torch 2.1.1 (CUDA 12.1)
- transformers 4.36.2
- accelerate 0.26.0
- deepspeed 0.12.3
Who can help?
@pacman100
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
This problem exists in the `PreTrainedModel` class in `modeling_utils.py` and would affect any code that loads a model. With DeepSpeed enabled, the model is wrapped by the DeepSpeed engine, and the normal parameter tensors `weight` and `bias` are changed: they become empty (shape = `torch.Size([0])`), and the actual values are stored in the `ds_tensor` attribute of `weight` and `bias`, respectively. This leads to a few problems in `modeling_utils.py` (a short sketch of the ZeRO-3 behavior follows the list below):
- Calling `model.state_dict().keys()` to get the expected model parameters: this uses PyTorch `Module`'s original `state_dict()`, and with DeepSpeed enabled it fails to return all parameter keys.
- Checking mismatched keys with `state_dict[checkpoint_key].shape != model_state_dict[model_key].shape`: here `model_state_dict[model_key].shape` is 0, so the check fails and matched keys are treated as mismatched. Those keys are then removed from the checkpoint's state_dict, and the corresponding weights are never loaded.
- Tied params: accelerate's `find_tied_parameters()` should be called to search for tied parameters when DeepSpeed is enabled, instead of relying on `model.state_dict().items()`.
- `resize_token_embeddings()`:
  - When creating the new embedding, the call is not wrapped in a DeepSpeed context, so the new embedding is not managed by DeepSpeed.
  - With the above fixed, before tying weights, the `embedding.shape` check must be wrapped in a DeepSpeed `GatheredParameters()` context.
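To make the failure mode concrete, here is a minimal sketch (not transformers code) of how ZeRO-3 partitioned parameters behave and how a shape check or weight access could account for them; the helper names are placeholders.

```python
import torch
import deepspeed
from accelerate.utils import find_tied_parameters


def logical_shape(param: torch.nn.Parameter) -> torch.Size:
    """Shape of a parameter as the checkpoint sees it, even when ZeRO-3 partitioned it."""
    # ZeRO-3 replaces the local data with an empty placeholder (torch.Size([0]))
    # and keeps the real shape on the ds_shape attribute.
    return param.ds_shape if hasattr(param, "ds_shape") else param.shape


def shape_matches(checkpoint_tensor: torch.Tensor, model_param: torch.nn.Parameter) -> bool:
    # The naive `checkpoint_tensor.shape != model_param.shape` check always reports a
    # mismatch under ZeRO-3, because model_param.shape is torch.Size([0]).
    return checkpoint_tensor.shape == logical_shape(model_param)


def gathered_weight_copy(param: torch.nn.Parameter) -> torch.Tensor:
    # Reading the actual values requires temporarily materializing the full tensor.
    with deepspeed.zero.GatheredParameters([param], modifier_rank=None):
        return param.detach().clone().cpu()


def tied_parameter_groups(model: torch.nn.Module):
    # Tied parameters should be found via accelerate rather than state_dict(),
    # which under ZeRO-3 does not expose the full set of keys.
    return find_tied_parameters(model)  # e.g. [["lm_head.weight", "model.embed_tokens.weight"]]
```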
Expected behavior
I made a fork of transformers and modified `modeling_utils.py` as in the following commit:
https://github.com/haixpham/transformers/commit/e300792ccb6fc53666b4971bab87ea7179a4e3bb
I would love to hear any feedback about my changes. I checked and compared the resulting values with/without DeepSpeed and they appeared similar.
Not sure if this helps, but I'm also having trouble reproducing the accuracy of CodeLlama 13B-I trained with ZeRO-2 when using ZeRO-3.
I hit the same problem when using this snippet to save a ZeRO-3 model.
Gentle ping @pacman100
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey! Pretty sure this was fixed on main!
> Hey! Pretty sure this was fixed on main!
I don't think so.
- For example, for the mismatched keys in `_load_pretrained_model()`, this condition check still does not take DeepSpeed's `ds_shape` into account. This is serious, because it causes checkpoint weights not being loaded at all: https://github.com/huggingface/transformers/blob/de11d0bdf0286f64616ea0d4b5778c41151a2d22/src/transformers/modeling_utils.py#L3994 In my fix, a specific check is carried out: https://github.com/haixpham/transformers/blob/a0204dbdfe1cc9a0203c9d30e3f3ec2f477c5cec/src/transformers/modeling_utils.py#L4068
- Another point: in `_get_resized_embeddings()` the HF code does not wrap the new `nn.Embedding` in a DeepSpeed context: https://github.com/huggingface/transformers/blob/de11d0bdf0286f64616ea0d4b5778c41151a2d22/src/transformers/modeling_utils.py#L2023 My fix: https://github.com/haixpham/transformers/blob/a0204dbdfe1cc9a0203c9d30e3f3ec2f477c5cec/src/transformers/modeling_utils.py#L1875 (a rough sketch of the resizing pattern follows this list)
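To illustrate the second point, here is a rough sketch of the pattern such a fix relies on, assuming the model was initialized under ZeRO-3. It is not the actual transformers implementation; `old_embeddings` and `new_num_tokens` are placeholders, and in practice the active DeepSpeed config would be passed to `zero.Init`.

```python
import torch.nn as nn
import deepspeed


def resize_embeddings_zero3(old_embeddings: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # Read the logical shape from ds_shape, since the local weight tensor is empty under ZeRO-3.
    weight = old_embeddings.weight
    old_num_tokens, embedding_dim = (
        weight.ds_shape if hasattr(weight, "ds_shape") else weight.shape
    )

    # Create the new embedding under zero.Init so DeepSpeed registers and partitions it.
    with deepspeed.zero.Init():
        new_embeddings = nn.Embedding(new_num_tokens, embedding_dim)

    # Materialize both weights to copy the overlapping rows; with modifier_rank=0 the
    # modification made on rank 0 is written back to the partitions on exit.
    params = [old_embeddings.weight, new_embeddings.weight]
    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
        n = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

    return new_embeddings
```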
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@haixpham can you provide a reproducer instead?
@ArthurZucker you can take any example that makes use of TrainingArguments / HfDeepSpeedConfig. This bug is in `PreTrainedModel`, and it affects any existing code. For example, the Hugging Face translation example with `--deepspeed ds_config_zero.json`: https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translation.py
- The weight loading problem is in this part (main branch): https://github.com/huggingface/transformers/blob/eb1a77bbb0a62d721e9a02e67b7e4f9e5afca08b/src/transformers/modeling_utils.py#L4089 Because ZeRO-3 zeroes out the weight tensor, the checkpoint tensor no longer matches the model weight tensor, and the checkpoint weights are not loaded. With this bug, when using DeepSpeed, weights are initialized randomly (when `_fast_init=False`) or with zeros (when `_fast_init=True`). You can monitor the loss value with/without DeepSpeed to see the difference. (A small diagnostic sketch follows this list.)
- The `_get_resized_embeddings()` problem is different: the new embedding/head layers are not wrapped in a DeepSpeed context, so they are reinitialized and not managed by ZeRO-3.
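As a quick way to observe the loading problem, one could gather a parameter after `from_pretrained()` under ZeRO-3 and compare it to the checkpoint. This is a hypothetical diagnostic, not part of transformers, and it assumes the checkpoint keys match the names from `model.named_parameters()`.

```python
import torch
import deepspeed
from safetensors.torch import load_file


def weight_was_loaded(model, checkpoint_file: str, param_name: str) -> bool:
    """Return True if the named parameter matches the tensor stored in the checkpoint."""
    checkpoint = load_file(checkpoint_file)  # e.g. a model.safetensors shard
    param = dict(model.named_parameters())[param_name]
    # Gather the ZeRO-3 partitions so the full tensor is visible on this rank.
    with deepspeed.zero.GatheredParameters([param], modifier_rank=None):
        return torch.allclose(
            param.detach().float().cpu(), checkpoint[param_name].float().cpu()
        )
```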
cc @muellerzr if you can have a look!