Save intermediate checkpoints during training
Is there any way to save intermediate checkpoints during training?
Sometimes my training may fail partway through for external reasons, so it would be helpful to save every N steps and continue where I left off.
Another use case: if I find the model overfitting by the end of the epoch, I could use one of the intermediate checkpoints instead.
Hey @l3utterfly - thanks for creating the issue!
This is definitely something that's on our minds, but it isn't available in the repo yet. Right now, if I'm worried that a long training run might fail partway through, I split the dataset into smaller pieces using the percent slicing syntax built into Hugging Face datasets splits.
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  split: train[:25%]
This trains on the first 25% of the dataset. I then take that checkpoint and pass it into the next training round with resume_from_checkpoint: True, moving the split window forward each time (for example train[25%:50%]). Four rounds cover the entire dataset.
I fully acknowledge this isn't the most efficient way to handle this, and we are looking into adding the feature, but the above method should at least unblock you for now.
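The round-by-round workaround can be sketched as a tiny helper that produces one split string per round. percent_splits is a hypothetical name, and the strings assume the Hugging Face datasets percent-slicing syntax used in the config above (and a dataset builder that actually accepts a split argument).

```python
from typing import List


def percent_splits(num_chunks: int, base: str = "train") -> List[str]:
    """Build percent-slice split strings that together cover the whole split.

    For num_chunks=4 this yields train[0%:25%], train[25%:50%], ...
    Boundaries are integer percentages, so num_chunks must divide 100.
    """
    if num_chunks <= 0 or 100 % num_chunks != 0:
        raise ValueError("num_chunks must be a positive divisor of 100")
    step = 100 // num_chunks
    return [f"{base}[{i * step}%:{(i + 1) * step}%]" for i in range(num_chunks)]


if __name__ == "__main__":
    for round_idx, split in enumerate(percent_splits(4)):
        # Round 0 starts from the base model; later rounds resume from the
        # checkpoint written by the previous round.
        print(f"round {round_idx}: split={split} resume_from_checkpoint={round_idx > 0}")
```

Each printed split goes into the dataset's split field for that round. Keeping the windows disjoint matters: reusing train[:25%] every round would just train on the same quarter four times.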
To add to Joe's response, another way to unblock yourself is to copy the recipe you're using and paste the save-checkpoint call into the training for loop, guarded by an if statement so that it saves every N steps.
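The "save every N steps" guard can be sketched generically as below. This is not torchtune's recipe code; train_with_periodic_checkpoints, train_step and save_checkpoint are hypothetical stand-ins for the recipe's loop body and its save method.

```python
from typing import Callable, Iterable, List


def train_with_periodic_checkpoints(
    batches: Iterable[object],
    train_step: Callable[[object], float],
    save_checkpoint: Callable[[int], None],
    save_every_n_steps: int,
) -> List[int]:
    """Run a training loop that calls save_checkpoint every save_every_n_steps steps.

    Returns the global steps at which a checkpoint was saved.
    """
    if save_every_n_steps <= 0:
        raise ValueError("save_every_n_steps must be positive")
    saved_at: List[int] = []
    global_step = 0
    for batch in batches:
        train_step(batch)  # stands in for forward, backward and optimizer.step()
        global_step += 1
        # The guard: only save on every N-th step.
        if global_step % save_every_n_steps == 0:
            save_checkpoint(global_step)
            saved_at.append(global_step)
    return saved_at
```

With 10 batches and save_every_n_steps=4, checkpoints are saved at steps 4 and 8. A step-based condition has a useful property in distributed recipes: global_step is identical on every rank, so all ranks enter the save path together, which FSDP requires when gathering a full state dict.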
@joecummings I'm trying the split: train[:25%] method.
The first 25% finished training and I can see the .pt files written to my output dir. However, no recipe_state.pt file was generated, so I can't resume.
How can I generate this file?
Ahh yes, our default recipes only generate that file for intermediate checkpoints (https://github.com/pytorch/torchtune/blob/15c918d65d79e03abcbd0c5e94d2b116bd368412/torchtune/utils/_checkpointing/_checkpointer.py#L590), e.g. for epoch 1 of 2. So you can either:
- Specify extra epochs if you want to use the default recipe and then kill training after a single epoch or
- tune cp the recipe you want and modify the code so that it calls save_checkpoint(..., intermediate_checkpoint=True)
I wish it were possible to take your current assets and use them to continue training right now, but recipe_state.pt also contains the current optimizer state, and resuming without it can significantly change how training proceeds. Apologies for the confusion!
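To show why the optimizer state matters for resuming, here is a generic PyTorch sketch (not torchtune's checkpointer) that saves weights together with optimizer state and a few counters, then restores them. The function names and dictionary keys are hypothetical.

```python
import torch
from torch import nn


def save_recipe_state(
    path: str,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epochs_run: int,
    seed: int,
) -> None:
    """Save weights together with the state needed to resume training faithfully.

    The key names are hypothetical; torchtune's checkpointer uses its own layout.
    """
    torch.save(
        {
            "model": model.state_dict(),
            # For AdamW this holds per-parameter step counts and moment estimates
            # (exp_avg, exp_avg_sq); a fresh optimizer starts these from zero.
            "optimizer": optimizer.state_dict(),
            "epochs_run": epochs_run,
            "seed": seed,
        },
        path,
    )


def load_recipe_state(
    path: str, model: nn.Module, optimizer: torch.optim.Optimizer
) -> int:
    """Restore model and optimizer in place and return the number of epochs already run."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epochs_run"]
```

Loading only the weights into a new optimizer would restart Adam's moment estimates and step counts, which is the "can significantly change your training" effect described above.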
@joecummings am I totally tripping, or is the split option above not supported anymore? Don't our dataset builders hardcode split=train?
TypeError: alpaca_dataset() got an unexpected keyword argument 'split'
They do hardcode split=train but it's not exposed in the alpaca_dataset builder itself. I'm supportive of making this small change.
How would you feel about it being exposed in general across our dataset builders?
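Exposing the split could look roughly like the sketch below: the builder takes a split argument (defaulting to "train", so current configs keep working) and forwards it to the loader instead of hardcoding it. dataset_builder_with_split and the injectable loader parameter are hypothetical, not torchtune's actual builder signature; by default it calls Hugging Face's load_dataset.

```python
from typing import Any, Callable, Optional


def dataset_builder_with_split(
    source: str,
    split: str = "train",
    loader: Optional[Callable[..., Any]] = None,
) -> Any:
    """Hypothetical builder that forwards `split` instead of hardcoding "train"."""
    if loader is None:
        # Imported lazily so the sketch can be exercised with a fake loader.
        from datasets import load_dataset

        loader = load_dataset
    return loader(source, split=split)
```

With this shape, a config like split: train[:25%] reaches the loader unchanged, and omitting split keeps today's behavior.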
sounds good to me
I followed this approach and defined a custom recipe with mid-epoch checkpointing. Unfortunately, simply calling the save-checkpoint method every N steps does not seem to work for distributed finetuning, with either LoRA or full finetuning. Whenever a checkpoint is saved after N steps, the training loop fails to resume and the GPUs time out with the following error:
[rank1]:[E ProcessGroupNCCL.cpp:563] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=313, OpType=_ALLGATHER_BASE, NumelIn=262668288, NumelOut=525336576, Timeout(ms)=600000) ran for 600010 milliseconds before timing out.
[rank0]:[E ProcessGroupNCCL.cpp:563] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=315, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600095 milliseconds before timing out.
[rank0]:[E ProcessGroupNCCL.cpp:1537] [PG 0 Rank 0] Timeout at NCCL work: 315, last enqueued NCCL work: 315, last completed NCCL work: 314.
[rank0]:[E ProcessGroupNCCL.cpp:577] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E ProcessGroupNCCL.cpp:583] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=315, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600095 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:565 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7ffa74b1f897 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7ffa75dfa1b2 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1a0 (0x7ffa75dfefd0 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7ffa75e0031c in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xdbbf4 (0x7ffac18a6bf4 in /home/anaconda3/envs/llm_env/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x8609 (0x7ffac2e1f609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7ffac2bea353 in /lib/x86_64-linux-gnu/libc.so.6)
[rank1]:[E ProcessGroupNCCL.cpp:1537] [PG 0 Rank 1] Timeout at NCCL work: 313, last enqueued NCCL work: 317, last completed NCCL work: 312.
[rank1]:[E ProcessGroupNCCL.cpp:577] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E ProcessGroupNCCL.cpp:583] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E ProcessGroupNCCL.cpp:1414] [PG 0 Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=313, OpType=_ALLGATHER_BASE, NumelIn=262668288, NumelOut=525336576, Timeout(ms)=600000) ran for 600010 milliseconds before timing out.
Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:565 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f20726af897 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x1d2 (0x7f207398a1b2 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1a0 (0x7f207398efd0 in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f207399031c in /home/anaconda3/envs/llm_env/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xdbbf4 (0x7f20bf436bf4 in /home/anaconda3/envs/llm_env/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x8609 (0x7f20c09af609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #6: clone + 0x43 (0x7f20c077a353 in /lib/x86_64-linux-gnu/libc.so.6)
Any thoughts on what might be causing this problem? At first I thought that saving the full state dict during training was too expensive, so I also tried distributed checkpointing with the SHARD_GRAD_OP strategy, but I end up at the same point.
@AugustoCapone can you share the modifications you made to save the checkpoint mid-epoch? One quick hack I can suggest is to call torch.distributed.barrier() after checkpoint save to make sure all ranks are ready before you resume training, but depending on how you save the checkpoint this should already be handled by the underlying APIs.
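A minimal sketch of the barrier hack, assuming a plain torch.save-based checkpoint written by rank 0 rather than torchtune's checkpointer; save_then_sync is a hypothetical name.

```python
import torch
import torch.distributed as dist


def save_then_sync(state: dict, path: str) -> None:
    """Write a checkpoint from rank 0, then hold every rank until the write is done."""
    if dist.get_rank() == 0:
        torch.save(state, path)
    # Every rank waits here, so no rank starts the next step's collectives
    # while rank 0 is still writing to disk.
    dist.barrier()
```

A barrier only helps when every rank actually calls it. If only some ranks enter the save path, the barrier becomes one more collective that never completes and the NCCL watchdog still fires.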
My first implementation, for full finetuning with FULL_SHARD, looks like this:
# Imports this recipe method relies on
from typing import Optional

from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torchtune import utils


def save_checkpoint(
    self,
    epoch: int,
    step: Optional[int] = None,
) -> None:
    """
    Save state dict to file. The recipe save_checkpoint method is responsible for
    correctly creating the checkpoint dict and passing to the checkpointer.
    """
    checkpoint_dict = {}
    # To prevent GPU memory from spiking during checkpoint save,
    # we consolidate the full model and optim state dicts on CPU for rank 0.
    # state_dict() here is a collective call: every rank must reach it.
    with FSDP.state_dict_type(
        self._model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        cpu_state_dict = self._model.state_dict()
    # Now that we have the model and opt state dict, create the actual checkpoint dict
    # to be sent to the checkpointer and ultimately written to file
    if self._is_rank_zero:
        checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})
        # if training is in-progress, also save the recipe state
        # (the optimizer state dict was removed for these first tests)
        if epoch + 1 < self.total_epochs:
            checkpoint_dict.update(
                {
                    utils.SEED_KEY: self.seed,
                    utils.EPOCHS_KEY: self.epochs_run,
                    utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                    utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )
        self._checkpointer.save_checkpoint(
            checkpoint_dict,
            epoch=epoch,
            step=step,
            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
        )
First, I modified the save_checkpoint method in the recipe and in _checkpointer.py to also include the step count in the generated file names. I also removed the optimizer state dict for the first tests.
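The exact file-name change isn't shown in the thread, so this is only a hypothetical naming helper illustrating the idea: end-of-epoch saves keep an epoch-only name, while mid-epoch saves also carry the global step so they don't overwrite each other.

```python
from typing import Optional


def checkpoint_filename(
    prefix: str, epoch: int, step: Optional[int] = None, ext: str = "pt"
) -> str:
    """Hypothetical naming scheme: append the global step to mid-epoch checkpoints."""
    if step is None:
        return f"{prefix}_{epoch}.{ext}"
    return f"{prefix}_{epoch}_step{step}.{ext}"
```

Without the step in the name, every mid-epoch save within the same epoch would write to the same file, leaving only the latest one on disk.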
Then I modified the training loop with the following condition:
...
if loss_to_log < best_loss:
    best_loss = loss_to_log
    self.save_checkpoint(
        epoch=curr_epoch,
        step=self.global_step,
    )
For this particular situation, I'm not sure whether torch.distributed.barrier() is enough to synchronize all processes.
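One plausible cause, offered as a hypothesis: self._model.state_dict() under FULL_STATE_DICT is a collective, so every rank must call save_checkpoint at the same step. If loss_to_log is each rank's local loss (or is only refreshed on rank 0), then loss_to_log < best_loss can be true on one rank and false on another. Some ranks then enter the save while others move on to the next forward pass, and they block on different collectives until the watchdog times out. That is consistent with the log above, where rank 0 waits on an ALLREDUCE and rank 1 on an _ALLGATHER_BASE. A barrier would not fix this, since the ranks would still reach it at different points. A sketch of one fix is to average the loss across ranks so every rank makes the same decision; global_mean_loss and maybe_save are hypothetical names.

```python
from typing import Callable

import torch
import torch.distributed as dist


def global_mean_loss(local_loss: torch.Tensor) -> float:
    """Average a scalar loss across all ranks so every rank sees the same value."""
    loss = local_loss.detach().float().clone()
    # SUM then divide: ReduceOp.AVG is not available on every backend.
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return (loss / dist.get_world_size()).item()


def maybe_save(
    local_loss: torch.Tensor, best_loss: float, save_fn: Callable[[], None]
) -> float:
    """Call save_fn on all ranks or on none, and return the updated best loss."""
    loss_to_log = global_mean_loss(local_loss)
    if loss_to_log < best_loss:  # same result on every rank
        # e.g. the recipe's save_checkpoint, which gathers the FSDP state dict collectively
        save_fn()
        return loss_to_log
    return best_loss
```

Every rank must call maybe_save (the all_reduce is itself a collective), and with NCCL the loss tensor should live on that rank's GPU, which is already the case for a loss computed in the training step.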