ZeRO does not initialize weights correctly
🐛 Describe the bug
I found that ZeRO does not initialize model parameters correctly.
- In ZeroInitContext, we convert each torch param to type ShardedParamV2 when the param is constructed. In this way, param.col_attr.payload maintains the payload of param.data. If rm_torch_payload_on_the_fly is True, we remove param.data immediately by setting it to a dummy tensor. https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/zero/init_ctx/init_context.py#L165
- After the param is constructed, some models initialize the param content at the end of __init__, e.g. https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L723. This assigns values to param.data, but at this time param.data is a dummy tensor, so param.col_attr.payload is not set correctly (see the sketch below).
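To illustrate, here is a minimal standalone sketch of the failure mode. FakeShardedParam and convert_params are hypothetical stand-ins for ShardedParamV2 and ZeroInitContext, not the real ColossalAI classes:

```python
import torch
import torch.nn as nn

class FakeShardedParam:
    """Hypothetical stand-in for ShardedParamV2: keeps the real payload."""
    def __init__(self, data):
        self.payload = data.clone()

def convert_params(module):
    # Mimic rm_torch_payload_on_the_fly=True: stash the payload and
    # replace param.data with an empty dummy tensor.
    for param in module.parameters():
        param.col_attr = FakeShardedParam(param.data)
        param.data = torch.empty(0)

linear = nn.Linear(4, 4)
convert_params(linear)

# Post-construction init (like BERT's _init_weights) writes to param.data,
# which is now the dummy tensor, so col_attr.payload never sees the values.
linear.weight.data.normal_(mean=0.0, std=0.02)   # no-op on the empty dummy
print(linear.weight.col_attr.payload)            # still the default init values
```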
Environment
No response
Work in progress
@1SAA did you fix the issue?
I have fixed part of this problem. If users only use the initializing functions in torch.nn.init, there is no problem. But correct initialization cannot be assured when users use initializing functions they wrote themselves.
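One plausible way to support the torch.nn.init functions is to wrap them so they operate on the ShardedParamV2 payload instead of the dummy data. The snippet below is a simplified sketch of that idea under that assumption, not ColossalAI's actual implementation; it reuses the hypothetical col_attr.payload layout from the sketch above:

```python
import functools
import torch.nn as nn

def _wrap_init_fn(fn):
    # Hypothetical wrapper: assumes the real payload lives at
    # param.col_attr.payload, as in the sketch above.
    @functools.wraps(fn)
    def wrapper(tensor, *args, **kwargs):
        col_attr = getattr(tensor, 'col_attr', None)
        if col_attr is not None:
            # Redirect the in-place init to the full-size payload.
            return fn(col_attr.payload, *args, **kwargs)
        return fn(tensor, *args, **kwargs)
    return wrapper

# Patch the public in-place initializers (normal_, uniform_, trunc_normal_, ...).
for name in dir(nn.init):
    if name.endswith('_') and not name.startswith('_'):
        setattr(nn.init, name, _wrap_init_fn(getattr(nn.init, name)))
```

This also shows where the remaining gap is: a user-written function may read the (dummy) tensor's shape before any nn.init call, so a wrapper like this never gets a chance to redirect it.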
Could you please give an example to illustrate the unsupported cases? @1SAA
A problem arises when an initializing function depends on the shape of the data tensor, as in the code below.
```python
import torch.nn as nn

def my_init_func1(tensor):
    # Reads the tensor's shape, which is the dummy tensor's shape under ZeRO init.
    fan_in = tensor.size(0)
    fan_out = tensor.size(1)
    nn.init.trunc_normal_(tensor, 2 / (fan_in + fan_out))

def my_init_func2(tensor):
    # Indexes into the tensor, again relying on its shape.
    for i in range(tensor.size(0)):
        nn.init.uniform_(tensor[i])

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        ...
        my_init_func1(self.proj.weight)
        my_init_func2(self.proj.bias)
```
If you really want to calculate fan_in and fan_out, please use nn.init._calculate_fan_in_and_fan_out.
Here is an example of correct init code.
```python
import torch.nn as nn

def my_init_func(tensor):
    # Works under ZeRO init: the fan values are computed by torch.nn.init itself.
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    nn.init.trunc_normal_(tensor, 2 / (fan_in + fan_out))

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        ...
        my_init_func(self.proj.weight)
        # xavier_uniform_ needs a tensor with at least 2 dims,
        # so use a 1-D-safe initializer for the bias.
        nn.init.uniform_(self.proj.bias)
```
Remember that you should not use .data for any of your parameters.
Just use those parameters as torch.Tensor, but do not try to use their shape attributes.
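For instance, a short illustration of these do's and don'ts, using a hypothetical module m:

```python
import torch.nn as nn

m = nn.Linear(16, 16)   # hypothetical module

# Don't: .data and the shape belong to the dummy tensor during ZeRO init.
# m.weight.data.fill_(0.0)
# scale = m.weight.size(0) ** -0.5

# Do: pass the parameter itself to a torch.nn.init function.
nn.init.zeros_(m.weight)
nn.init.normal_(m.bias, std=0.02)
```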
We have made a lot of updates since then. This issue was closed due to inactivity. Thanks.