
[BUG]: SaveCheckpointHook does not save optimizer and scheduler parameters

Open liuslnlp opened this issue 2 years ago • 9 comments

πŸ› Describe the bug

As mentioned in #2569, I tried to use SaveCheckpointHook to save a checkpoint of the titans GPT model during hybrid parallel training. However, only the model state is saved, not the optimizer and lr scheduler states.

This hook calls the function colossalai.utils.checkpointing.save_checkpoint (lines 154-190 of colossalai/utils/checkpointing.py), which does not implement saving of the optimizer parameters:

def save_checkpoint(file,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
    lr_scheduler etc. into a checkpoint dictionary.

    Args:
        file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
            file name.
        epoch (int): Epoch number (indicates how many epochs have you trained this model).
        model (:class:`torch.nn.Module`): Model to be saved.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
            lr_scheduler to be saved, defaults to None.
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    # ckpt container
    checkpoint = {"epoch": epoch}

    model_state = model.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        model_state = gather_pipeline_parallel_state_dict(model_state)

    if gpc.get_global_rank() == 0:
        checkpoint["model"] = model_state

        # if optimizer is not None:
        #     checkpoint['optimizer'] = optimizer.state_dict()

        # if lr_scheduler is not None:
        #     checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

        torch.save(checkpoint, file, **kwargs)

So I'm wondering how I can save and load the optimizer parameters under hybrid parallelism and ZeRO-3 (examples/language/gpt/titans/train_gpt.py).

Environment

The same as examples/language/gpt/titans/requirements.txt.

liuslnlp avatar Feb 08 '23 12:02 liuslnlp

Please first uncomment the code here.

Then append a hook such as hooks.SaveCheckpointHook(10, checkpoint_dir='./ckpt', model=trainer.engine.model, save_by_iter=True) to the hook list here, for example as sketched below.
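For concreteness, a rough sketch of what the hook registration in train_gpt.py could look like (the LossHook, trainer, and dataloader names are assumptions based on the titans example, not verbatim from it):

from colossalai.trainer import hooks

# Sketch: extend the training hook list so a checkpoint is written
# every 10 iterations. `trainer` is the colossalai Trainer already
# built in train_gpt.py.
hook_list = [
    hooks.LossHook(),
    hooks.SaveCheckpointHook(
        10,                          # save interval, in iterations
        checkpoint_dir='./ckpt',
        model=trainer.engine.model,
        save_by_iter=True,
    ),
]
trainer.fit(train_dataloader=train_dataloader,
            epochs=num_epochs,
            hooks=hook_list,
            display_progress=True)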

To load a checkpoint, open a python3 session and run:

>>> import torch
>>> ckpt = torch.load('./ckpt')
>>> print(ckpt.keys())
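If the save ran with the optimizer lines uncommented, the extra keys should now be present; a quick check (key names taken from the commented-out code above):

>>> # expect both to be True once optimizer/scheduler saving is enabled
>>> 'optimizer' in ckpt, 'lr_scheduler' in ckpt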

JThh avatar Feb 08 '23 16:02 JThh

Hello @JThh, after uncommenting the following code, only rank 0 saves the optimizer states.

However, due to the ZeRO-3 mechanism, each worker keeps only part of the optimizer states, so this method saves only 1 / world_size of the optimizer parameters (a per-rank workaround is sketched after the snippet below).

    if gpc.get_global_rank() == 0:
        checkpoint["model"] = model_state

        # if optimizer is not None:
        #     checkpoint['optimizer'] = optimizer.state_dict()

        # if lr_scheduler is not None:
        #     checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

        torch.save(checkpoint, file, **kwargs)
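One common workaround for this (a sketch of a generic per-rank save, not an existing ColossalAI API) is to let every rank persist the optimizer shard it owns, and to restore shard-by-shard on load. This assumes the restart uses the same world size and partitioning as the original run:

import os
import torch
import torch.distributed as dist

def save_sharded_optimizer(optimizer, ckpt_dir):
    # Under ZeRO-3 each worker holds 1/world_size of the optimizer
    # states, so every rank must write its own shard.
    os.makedirs(ckpt_dir, exist_ok=True)
    rank = dist.get_rank()
    torch.save(optimizer.state_dict(),
               os.path.join(ckpt_dir, f'optimizer_rank{rank}.pt'))

def load_sharded_optimizer(optimizer, ckpt_dir):
    # Each rank reloads exactly the shard it saved earlier.
    rank = dist.get_rank()
    state = torch.load(os.path.join(ckpt_dir, f'optimizer_rank{rank}.pt'),
                       map_location='cpu')
    optimizer.load_state_dict(state)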

liuslnlp avatar Feb 09 '23 05:02 liuslnlp

Hi @liuslnlp, your point is valid. Can you try this?

JThh avatar Feb 09 '23 10:02 JThh

@JThh, this method has the same effect as the previous one.

liuslnlp avatar Feb 10 '23 10:02 liuslnlp

No. This method should be able to gather the optimizer states before saving.
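For reference, the gathering could look roughly like the following (a sketch assuming an initialized torch.distributed process group, not the code behind the link): each rank contributes its local optimizer shard, and rank 0 ends up with all of them before writing the checkpoint.

import torch.distributed as dist

def gather_optimizer_states(optimizer):
    # Collect every rank's optimizer shard; all_gather_object pickles
    # arbitrary Python objects across the process group.
    shards = [None] * dist.get_world_size()
    dist.all_gather_object(shards, optimizer.state_dict())
    # Only rank 0 needs the full list for checkpointing.
    return shards if dist.get_rank() == 0 else None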

JThh avatar Feb 10 '23 17:02 JThh

So in order to save titans GPT, which method should we use?

JingxinLee avatar Mar 14 '23 09:03 JingxinLee

This method does not work with the ChatGPT example + ZeRO-2.

Hi @liuslnlp , your point is valid. Can you try this?

hijkzzz avatar Mar 27 '23 00:03 hijkzzz


Hi @hijkzzz, you can refer to https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat#faq for the ChatGPT example. For more general cases, a new checkpoint system will come soon. Thanks.

binmakeswell avatar Apr 18 '23 09:04 binmakeswell