
ZeRO does not initialize weights correctly

Open feifeibear opened this issue 3 years ago • 5 comments

🐛 Describe the bug

I suspect that ZeRO does not initialize model parameters correctly.

  1. In ZeroInitContext, we convert each torch param to type ShardedParamV2 as the param is constructed. In this way, param.col_attr.payload holds the payload of param.data. If rm_torch_payload_on_the_fly is True, we remove param.data immediately by setting it to a dummy tensor. https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/zero/init_ctx/init_context.py#L165

  2. After the param is constructed, some models initialize the param contents at the end of their own init, e.g. https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L723, which assigns values to param.data. However, at this time param.data has already been replaced by the dummy tensor, so param.col_attr.payload is not set correctly. A minimal reproduction of this failure mode is sketched below.
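
The core of the failure can be shown without ColossalAI. Below is a minimal sketch (DummyShardedParam and its payload field are made-up stand-ins for ShardedParamV2, not the real implementation) of why an in-place write through param.data is lost once the payload has been detached:

import torch
import torch.nn as nn

class DummyShardedParam:
    # Illustrative stand-in: keeps its own copy of the data, like a sharded payload.
    def __init__(self, param: nn.Parameter):
        self.payload = param.data.clone()

linear = nn.Linear(4, 4)
linear.weight.col_attr = DummyShardedParam(linear.weight)  # step 1: attach the payload
linear.weight.data = torch.empty(0)                        # step 1: rm_torch_payload_on_the_fly -> dummy tensor

# Step 2: a HuggingFace-style post-init (e.g. BertPreTrainedModel._init_weights)
# writes into param.data in place, but .data is now only the 0-element dummy tensor.
linear.weight.data.normal_(mean=0.0, std=0.02)

print(linear.weight.col_attr.payload)  # still holds the old, pre-init values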

Environment

No response

feifeibear avatar Mar 29 '22 03:03 feifeibear

Work in progress.

binmakeswell avatar Apr 13 '22 05:04 binmakeswell

@1SAA did you fix the issue?

feifeibear avatar Apr 13 '22 05:04 feifeibear

I have fixed part of this problem. If users only use the initializing functions in torch.nn.init, there is no problem. However, correct initialization cannot be guaranteed when users write their own initializing functions.
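
For reference, one plausible way to cover the torch.nn.init functions (a sketch of the general idea, not the actual ColossalAI patch) is to wrap every in-place initializer so that it writes into the sharded payload instead of the dummy .data tensor:

import functools
import torch.nn as nn

def _redirect_to_payload(init_fn):
    # Hypothetical wrapper: if the tensor carries a sharded payload (col_attr),
    # run the initializer on the payload instead of the (possibly dummy) tensor.
    @functools.wraps(init_fn)
    def wrapper(tensor, *args, **kwargs):
        col_attr = getattr(tensor, 'col_attr', None)
        if col_attr is not None:
            return init_fn(col_attr.payload, *args, **kwargs)
        return init_fn(tensor, *args, **kwargs)
    return wrapper

# Patch every public in-place initializer (names ending in '_') while the
# ZeRO init context is active.
for name in dir(nn.init):
    fn = getattr(nn.init, name)
    if callable(fn) and name.endswith('_') and not name.startswith('_'):
        setattr(nn.init, name, _redirect_to_payload(fn))

A user-defined function that reads tensor.size() before calling one of these wrapped initializers still sees the dummy shape, which is exactly the unsupported case asked about below.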

1SAA avatar Apr 15 '22 06:04 1SAA

Could you please give an example to illustrate the unsupported cases? @1SAA

feifeibear avatar Apr 15 '22 06:04 feifeibear

A problem arises when an initializing function depends on the shape of the data tensor, as in the code below.

import torch.nn as nn


def my_init_func1(tensor):
    # Reads the tensor's shape directly, which is not valid under ZeRO because
    # the parameter's .data may have been replaced by a dummy tensor.
    fan_in = tensor.size(0)
    fan_out = tensor.size(1)
    nn.init.trunc_normal_(tensor, 2 / (fan_in + fan_out))


def my_init_func2(tensor):
    # Also shape-dependent: iterates over the first dimension.
    for i in range(tensor.size(0)):
        nn.init.uniform_(tensor[i])


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.proj = nn.Linear(16, 16)
        ...
        my_init_func1(self.proj.weight)
        my_init_func2(self.proj.bias)

If you really want to calculate fan_in and fan_out, please use nn.init._calculate_fan_in_and_fan_out. Here is an example of correct init code.

import torch.nn as nn


def my_init_func(tensor):
    # Let torch.nn.init compute fan_in/fan_out instead of reading the shape yourself.
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    nn.init.trunc_normal_(tensor, 2 / (fan_in + fan_out))


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.proj = nn.Linear(16, 16)
        ...
        my_init_func(self.proj.weight)
        nn.init.uniform_(self.proj.bias)  # any in-place torch.nn.init function works for the bias

Remember that you should not use .data on any of your parameters. Just use those parameters as torch.Tensor objects, but do not try to read their shape attributes directly, as shown below.
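
For example (an illustrative snippet, not ColossalAI code), prefer calling a torch.nn.init function on the parameter itself over writing through .data:

import torch.nn as nn

proj = nn.Linear(16, 16)

# Avoid: writes through .data, which ZeRO may have replaced with a dummy tensor,
# so the initialization can be silently lost.
proj.weight.data.normal_(mean=0.0, std=0.02)

# Prefer: call an in-place torch.nn.init function on the parameter directly.
nn.init.normal_(proj.weight, mean=0.0, std=0.02)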

1SAA avatar Apr 15 '22 07:04 1SAA

We have updated the codebase a lot since then. This issue was closed due to inactivity. Thanks.

binmakeswell avatar Apr 13 '23 03:04 binmakeswell