zero3 checkpoint frozen params
- Enable checkpoint load/save of frozen params in zero stage 3.
- Fix #3090
- Pending task: Update zero_to_fp32.py to recover frozen weights.
@stas00, FYI
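For context, a minimal sketch of the scenario this enables (the toy model and config values below are illustrative, not part of this PR): freeze some parameters, train under ZeRO stage 3, and expect the frozen weights to round-trip through save_checkpoint/load_checkpoint.

import deepspeed
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.Linear(1024, 1024),
)
# Freeze the first layer; only the second layer is trainable.
for p in model[0].parameters():
    p.requires_grad = False

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)

# With this change the frozen layer's weights should be captured here as well,
# and restored later by engine.load_checkpoint("save_dir").
engine.save_checkpoint("save_dir")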
I tried it out, and when the checkpoint is saved, almost all of the frozen weights come out with size [0]:
python tools/convert_checkpoint/inspect_checkpoint.py /hf/m4-master-3/save_dir/opt_step-10/accelerator_state/pytorch_model/zero_pp_rank_0_mp_rank_00_model_states.pt
loading checkpoint file: /hf/m4-master-3/save_dir/opt_step-10/accelerator_state/pytorch_model/zero_pp_rank_0_mp_rank_00_model_states.pt
[tensor] module.lm_head.weight = torch.Size([0])
[tensor] module.lm_head.additional_fc.weight = torch.Size([0])
[tensor] module.model.decoder.embed_tokens.weight = torch.Size([0])
[...]
I think they need to be gathered before saving.
But we probably shouldn't do that on every process, as it'd be quite slow if the model has 50% frozen weights. Since the frozen weights are identical on every rank, saving them once should be enough (at least on a shared fs; it won't work on a non-shared fs).
The following will do the gathering:
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 8c31a9d6..8b91e242 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -357,7 +357,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         param_groups = []
         for param_group in self.optimizer.param_groups:
             frozen_params = [p for p in param_group["params"] if not p.requires_grad]
-            param_groups.append(frozen_params)
+            with deepspeed.zero.GatheredParameters(frozen_params, modifier_rank=None):
+                param_groups.append(frozen_params)
         return param_groups
 
     def _setup_for_real_optimizer(self):
But the saved tensors still appear to be of size 0, so that fix doesn't seem to be it.
Ah, I see: the original code will never succeed because frozen params aren't in optimizer.param_groups in the first place.
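For reference, a minimal sketch of collecting the frozen params from the module itself (rather than from optimizer.param_groups) and gathering their full tensors; the helper name is made up and this is not necessarily what the PR ships:

import deepspeed

def gathered_frozen_state(module):
    # Frozen params live on the module, not in optimizer.param_groups.
    frozen = {n: p for n, p in module.named_parameters() if not p.requires_grad}
    # Materialize the full tensors; modifier_rank=None because we only read them.
    with deepspeed.zero.GatheredParameters(list(frozen.values()), modifier_rank=None):
        return {n: p.detach().cpu().clone() for n, p in frozen.items()}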
I'm also wondering whether this would even work for a huge model with a lot of frozen params: there might not be enough memory to gather them all. Perhaps we should save their fp16 shards instead? That would also be much faster.
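To illustrate the shard-saving idea, a rough sketch that leans on ZeRO-3 internals (ds_tensor holds this rank's partition, ds_shape the full shape); these are not a stable public API, and this is not necessarily how the PR ends up implementing it:

import torch

def frozen_shard_state(module):
    shapes, fragments = {}, {}
    for name, p in module.named_parameters():
        if p.requires_grad:
            continue
        shapes[name] = p.ds_shape  # full (unpartitioned) shape
        fragments[name] = p.ds_tensor.detach().cpu().clone()  # this rank's shard
    return {"shapes": shapes, "fragments": fragments}

# Each rank would then save only its own file, e.g.:
# torch.save(frozen_shard_state(module), f"frozen_shards_rank{torch.distributed.get_rank()}.pt")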
Hi @stas00 and @tjruwase, thanks for your work on this. I'm just checking to see if this would fix an error I'm getting using DeepSpeed and LoRA. Let me know if this isn't the place to ask.
I'm able to train "t5" using DeepSpeed Stage 3 and LoRA; however, when I call load_state_dict_from_zero_checkpoint I get the error KeyError: '_forward_module.model.base_model.model.encoder.embed_tokens.weight'.
Thanks again for all your help!
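For what it's worth, this is roughly the call pattern in question (details are guessed for this sketch, not a full repro): a LoRA-wrapped T5 trained under ZeRO stage 3, followed by loading the consolidated weights back into the model.

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from peft import LoraConfig, get_peft_model
from transformers import T5ForConditionalGeneration

base = T5ForConditionalGeneration.from_pretrained("t5-small")
model = get_peft_model(base, LoraConfig(task_type="SEQ_2_SEQ_LM"))

# ... train with DeepSpeed ZeRO stage 3; checkpoints are written to checkpoint_dir ...

# This call is where the KeyError on the (frozen) base-model weight appears.
model = load_state_dict_from_zero_checkpoint(model, "checkpoint_dir")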
@shaankhosla, thanks for your interest. Please open a new ticket for this problem. It would be very helpful to provide more details for reproducing the problem in that ticket.
Here it is: #3291 :)
Thank you for the quick fix and merge, Tunji and the team!