ColossalAI
Fix state_dict key missing issue of the ZeroDDP
#2361
Hi,
I tried to dig into the root cause of this issue.
The reason the checkpoint key missing issue is raised for GPT2 is that in ZeroDDP, the state_dict of the model is dumped by the method self.named_parameters(), which returns the output of PyTorch's named_parameters().
However, in the original PyTorch implementation, named_parameters() returns weights that share the same memory address only once. This means that if two layers share the same weight, as the embedding layer and the final classifier layer do in GPT2, named_parameters() returns only one of them.
By contrast, the original PyTorch implementation of state_dict returns the whole state of the module.
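A quick demonstration of the difference (a minimal sketch; the tied Embedding/Linear pair below stands in for GPT2's token embedding and classifier head):

import torch.nn as nn

# Tie two modules' weights, as GPT2 does for its token embedding
# and final classifier.
embed = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = embed.weight  # both parameters now share the same storage

model = nn.Sequential(embed, head)
print([name for name, _ in model.named_parameters()])
# ['0.weight']  -- the shared weight is reported only once
print(list(model.state_dict().keys()))
# ['0.weight', '1.weight']  -- state_dict keeps a key for every module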
Therefore, to address this issue, I create a dict named name_to_fp32_params to store the mapping between the names in the whole module state and the fp32 params.
# Set keep_vars=True to prevent the parameters from being detached
for name, p in module.state_dict(keep_vars=True).items():
    if p in params_to_fp32_params:
        self.name_to_fp32_params[name] = params_to_fp32_params[p]
And in the method _save_to_state_dict(), instead of utilizing the function named_parameters(), I use self.module.state_dict() to get the whole model state.
for name, p in self.module.state_dict().items():
    if p is not None and name in self.name_to_fp32_params:
        fp32_p = self.name_to_fp32_params[name]
        assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
        record_parameter = param_to_save_data[fp32_p]
    else:
        record_parameter = p
    destination[prefix + name] = record_parameter
With the modification shown above, state_dict outputs the correct key-value pairs for GPT2 without any impact on other models.
Thanks @eric8607242 , very impressive. Could you please resolve the conflict? @1SAA please help review this PR.
@feifeibear Hi, thanks for the quick response! The conflict has been fixed!
Hi @eric8607242
Thanks for your great PR. 😊
But we just added a powerful function called get_static_torch_model, which can generate a torch model from a GeminiDDP module. You can find the function here. Could you try it like this?
state_dict = get_static_torch_model(model).state_dict()
Does this function meet your requirement? I think we could use it to implement our new strict state_dict() function.
Hi @1SAA , Thanks for the great work and awesome function.
I think this is exactly what I want!
But the new function is coupled with ZeroDDP.state_dict, so implementing the strict state_dict with it is a bit of a chicken-and-egg problem.
And I am concerned this may confuse other users as it did me.
Hi @eric8607242
I think we can rename the current state_dict function as _non_strict_state_dict and write a new state_dict function based on get_static_torch_model. In my opinion, this way requires the fewest edits under the current circumstances.
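A minimal sketch of that split (illustrative only; the exact signatures and the availability of get_static_torch_model in scope are assumptions, not the real ColossalAI code):

import torch.nn as nn
# get_static_torch_model is assumed to be importable from a colossalai utility module

class ZeroDDP(nn.Module):  # simplified stand-in for the real class
    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False):
        # the current named_parameters()-based implementation, kept under a private name
        ...

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Build a plain torch model first, then reuse PyTorch's own state_dict(),
        # which keeps a key for every module, tied weights included.
        torch_model = get_static_torch_model(self)
        return torch_model.state_dict(destination=destination, prefix=prefix,
                                      keep_vars=keep_vars)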
Hi @1SAA , I think this is a great idea. Thanks for your suggestion!
I will modify the method accordingly as soon as possible.
Hi @1SAA,
I rewrote the state_dict method with a new bool argument strict to specify which version of state_dict is returned. The default value is True, which returns the state_dict following the usual PyTorch semantics, to avoid user confusion.
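A rough sketch of the new dispatch (the exact signature is my assumption; see the PR diff for the actual code):

def state_dict(self, destination=None, prefix='', keep_vars=False, strict=True):
    if strict:
        # PyTorch-style output: a key for every module, tied weights included
        return get_static_torch_model(self).state_dict(destination=destination,
                                                       prefix=prefix,
                                                       keep_vars=keep_vars)
    # previous behavior, based on named_parameters()
    return self._non_strict_state_dict(destination, prefix, keep_vars)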
Look forward to your feedback. Thanks!
Hi @eric8607242
LGTM.
Hi @1SAA , thanks for the review!
I modified the argument of the get_static_torch_model function from a GeminiDDP module to a ZeroDDP module, to be compatible with lower versions (v0.1.9 and v0.1.10).