
[BUG]: Parameters missing in the state_dict output of the ZeroDDP module

Open · eric8607242 opened this issue 2 years ago • 1 comment

🐛 Describe the bug

Hello, I am currently fine-tuning the Hugging Face GPT2 model with ColossalAI, following the example that uses GeminiDDP and the ZeroOptimizer. However, I found that some keys are missing when I load a checkpoint saved by ColossalAI back into the GPT2 model.

The following code reproduces the issue:

import colossalai
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import save_checkpoint
from colossalai.utils.model.colo_init_context import ColoInitContext

from transformers import GPT2LMHeadModel
import torch

if __name__ == "__main__":
    device = "cuda:0"
    tp_degree = 1
    shardinit = True
    placement_policy = "cpu"
    path_to_checkpoint = "./test.pth"

    # Initialize the distributed environment (expects the torch launcher).
    colossalai.launch_from_torch(config={})

    # Optionally shard parameters along the last dimension at init time.
    default_pg = ProcessGroup(tp_degree=tp_degree)
    default_dist_spec = ShardSpec([-1], [tp_degree]) if shardinit else None
    with ColoInitContext(
        device,
        dtype=torch.half,
        default_dist_spec=default_dist_spec,
        default_pg=default_pg
    ):
        model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Wrap the model with GeminiDDP (ZeroDDP with the Gemini chunk manager).
    model = GeminiDDP(
        model, device=device,
        placement_policy=placement_policy,
        pin_memory=True, hidden_dim=768,
        search_range_mb=64
    )

    # Save a checkpoint, then try to load it back into a plain GPT2 model.
    save_checkpoint(path_to_checkpoint, epoch=0, model=model)

    original_model = GPT2LMHeadModel.from_pretrained("gpt2")
    checkpoint = torch.load(path_to_checkpoint)
    original_model.load_state_dict(checkpoint["model"])
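
Note: since the script calls colossalai.launch_from_torch, it presumably has to be started with the torch distributed launcher, e.g. torchrun --nproc_per_node=1 repro.py (where repro.py is a placeholder name for the script above).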

The error message is as follows:

RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
        Missing key(s) in state_dict: "lm_head.weight". 
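
As a quick sanity check (a minimal sketch, assuming the ./test.pth file written by the script above), the key is indeed absent from the saved state dict itself, so the problem occurs at save time rather than load time:

    import torch

    checkpoint = torch.load("./test.pth")
    # lm_head.weight is missing while the other GPT2 parameters are present
    print("lm_head.weight" in checkpoint["model"])  # prints False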

Environment

OS: Ubuntu 22.04

GPU: NVIDIA GeForce RTX 3090

Packages: pytorch 2.0.0_cuda11.6_cudnn8.3.2_0, cuda-toolkit 11.6.1, colossalai 0.2.0+torch2.0cu11.7

eric8607242 · Jan 06 '23 05:01

Hi @eric8607242

You can solve this problem by adding one argument to your load_state_dict call, like this:

    original_model.load_state_dict(checkpoint["model"], strict=False)
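
For context (my understanding of the cause, not verified against the ColossalAI internals): GPT2LMHeadModel ties lm_head.weight to the input embedding transformer.wte.weight, so the saved state dict appears to keep only one copy of the shared tensor. With strict=False the loader simply skips the missing tied key, and the tie can be re-established explicitly afterwards. A minimal sketch:

    # load everything except the tied output head
    missing, unexpected = original_model.load_state_dict(checkpoint["model"], strict=False)
    print(missing)  # expected: ['lm_head.weight']

    # re-tie the output head to the loaded embedding weights
    original_model.tie_weights()
    assert (original_model.lm_head.weight.data_ptr()
            == original_model.transformer.wte.weight.data_ptr())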

1SAA · Jan 06 '23 06:01