
Fix state_dict key missing issue of the ZeroDDP

Open · eric8607242 opened this issue 2 years ago · 7 comments

#2361 Hi, I tried to dig into the root cause of this issue. After digging, I found that the checkpoint key missing issue is raised for GPT2 because, in ZeroDDP, the state_dict of the model is dumped via self.named_parameters(), which returns the output of PyTorch's named_parameters().

However, PyTorch's implementation of named_parameters() returns weights that share the same memory address only once. This means that if two layers share the same weight, as the embedding layer and the final classifier layer do in GPT2, named_parameters() returns only one of them. PyTorch's implementation of state_dict(), in contrast, returns the whole state of the module.
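To illustrate the difference, here is a minimal, self-contained sketch; the toy model is hypothetical, mimicking GPT2's tying of the embedding and classifier weights:

import torch.nn as nn

# A toy model whose classifier head shares its weight tensor with the
# embedding layer, as GPT2 does.
class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight  # tie the weights

model = TiedModel()

# named_parameters() reports a shared tensor only once, under its first name.
print([name for name, _ in model.named_parameters()])  # ['embed.weight']

# state_dict() keeps both keys, even though they point to the same data.
print(list(model.state_dict().keys()))  # ['embed.weight', 'head.weight']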

Therefore, to address this issue, I created a dict named name_to_fp32_params that stores the mapping from the names in the whole module state to the fp32 params.

# Set keep_vars=True so that state_dict() returns the parameters themselves
# rather than detached copies; the identity lookup below relies on this.
for name, p in module.state_dict(keep_vars=True).items():
    if p in params_to_fp32_params:
        self.name_to_fp32_params[name] = params_to_fp32_params[p]

And in the method _save_to_state_dict(), instead of utilizing named_parameters(), I use self.module.state_dict() to get the whole model state.

# Iterate over the full module state so that tied weights appear under
# every registered name.
for name, p in self.module.state_dict().items():
    if p is not None and name in self.name_to_fp32_params:
        # The tensor is managed by a chunk: record its saved fp32 copy.
        fp32_p = self.name_to_fp32_params[name]
        assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
        record_parameter = param_to_save_data[fp32_p]
    else:
        # Buffers and unmanaged tensors are recorded as-is.
        record_parameter = p
    destination[prefix + name] = record_parameter

With the modification shown above, state_dict() outputs the correct key-value pairs for GPT2 without any impact on other models.

eric8607242 avatar Jan 06 '23 06:01 eric8607242

Thanks @eric8607242, very impressive. Could you please resolve the conflict? @1SAA, please help review this PR.

feifeibear avatar Jan 06 '23 06:01 feifeibear

@feifeibear Hi, thanks for the quick response! The conflict is resolved!

eric8607242 avatar Jan 06 '23 06:01 eric8607242

Hi @eric8607242

Thanks for your great PR. 😊 But we just added a powerful function called get_static_torch_model which can generate a torch model from a GeminiDDP module. You can find the function here. Could you try it like state_dict = get_static_torch_model(model).state_dict()? Does it meet your requirement? I think we could use this function to implement our new strict state_dict() function.
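For reference, a minimal sketch of the suggested call pattern; the import path below is an assumption and has moved between ColossalAI versions:

# Hypothetical import path; check where get_static_torch_model lives in
# your ColossalAI version.
from colossalai.nn.parallel.utils import get_static_torch_model

# model is a GeminiDDP/ZeroDDP-wrapped module; the helper rebuilds a plain
# torch model, whose state_dict() then contains every registered key.
torch_model = get_static_torch_model(model)
state_dict = torch_model.state_dict()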

1SAA avatar Jan 06 '23 07:01 1SAA

Hi @1SAA, thanks for the great work and awesome function.

I think this is exactly what I want!

But the new function is coupled with ZeroDDP.state_dict. Implementing the strict state_dict with the new function is a bit of a chicken-and-egg problem, and I am concerned this may confuse other users as it did me.

eric8607242 avatar Jan 06 '23 07:01 eric8607242

Hi @eric8607242

I think we can rename the current state_dict function to _non_strict_state_dict and write a new state_dict function based on get_static_torch_model. In my opinion, this approach requires the fewest edits under the current circumstances.

1SAA avatar Jan 06 '23 08:01 1SAA

Hi @1SAA, I think this is a great idea. Thanks for your suggestion!

I will follow this approach and modify the method as soon as possible.

eric8607242 avatar Jan 06 '23 08:01 eric8607242

Hi @1SAA, I rewrote the state_dict method with a new bool argument strict to specify which version of state_dict is returned. The default value is True so that state_dict matches standard PyTorch usage and avoids user confusion. A sketch is below.
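A minimal sketch of what this looks like (hypothetical signature; the real implementation lives in the PR):

def state_dict(self, destination=None, prefix='', keep_vars=False, strict=True):
    if strict:
        # Rebuild a plain torch model and delegate to its state_dict(), so
        # tied weights appear under every registered name, matching PyTorch.
        # (Internally, get_static_torch_model would itself need the
        # non-strict path to avoid infinite recursion.)
        torch_model = get_static_torch_model(self)
        return torch_model.state_dict(destination=destination, prefix=prefix,
                                      keep_vars=keep_vars)
    # strict=False keeps the previous named_parameters()-based behavior,
    # renamed to _non_strict_state_dict() as suggested above.
    return self._non_strict_state_dict(destination, prefix, keep_vars)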

Looking forward to your feedback. Thanks!

eric8607242 avatar Jan 07 '23 01:01 eric8607242

Hi @eric8607242

LGTM.

1SAA avatar Jan 09 '23 03:01 1SAA

Hi @1SAA, thanks for the review!

I changed the argument of get_static_torch_model from a GeminiDDP module to a ZeroDDP module to be compatible with lower versions (v0.1.9 and v0.1.10).
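For clarity, a minimal before/after sketch of the signature change (names are illustrative):

# before: only the newer wrapper type was accepted
#   def get_static_torch_model(gemini_ddp_model: GeminiDDP) -> torch.nn.Module
# after: the ZeroDDP type is accepted, which also exists in v0.1.9/v0.1.10
#   def get_static_torch_model(zero_ddp_model: ZeroDDP) -> torch.nn.Module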

eric8607242 avatar Jan 09 '23 04:01 eric8607242