Megatron-LM

[BUG] Some checkpoint shards are not saved and the save hangs on multi-node setups (since v0.7)

chotzen opened this issue 1 year ago · 1 comment

Describe the bug

We're in the process of upgrading Megatron-Core from 0.6 to 0.8 and have noticed problematic behavior with the new distributed async checkpoint saving introduced in mcore 0.7.

When saving a checkpoint for a small model (roughly Llama 3 8B sized) on a moderately large training setup (8 nodes, 2-way MP × 32-way DDP), the checkpoint save hangs. On further inspection, only 123 of the expected 128 checkpoint shards were written to the directory, and which shards are missing varies from attempt to attempt.
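To see which shards are missing on a given attempt, a sketch like the following can diff the directory listing against the expected set. It assumes torch.distributed.checkpoint's default file naming `__{rank}_{index}.distcp`, 64 writing ranks (2-way MP × 32-way DDP) and 2 files per rank (128 / 64); `find_missing_shards` is a hypothetical helper, not part of Megatron-Core. Files that don't match the pattern (metadata and the like) are ignored.

```python
import os
import re
from typing import Iterable, List, Set, Tuple

# Assumed shard naming: __{rank}_{file_index}.distcp (torch DCP default).
SHARD_RE = re.compile(r"^__(\d+)_(\d+)\.distcp$")


def find_missing_shards(
    filenames: Iterable[str], world_size: int, files_per_rank: int
) -> List[Tuple[int, int]]:
    """Return sorted (rank, file_index) pairs expected but not present."""
    present: Set[Tuple[int, int]] = set()
    for name in filenames:
        match = SHARD_RE.match(name)
        if match:  # skip metadata and any other non-shard files
            present.add((int(match.group(1)), int(match.group(2))))
    expected = {(r, i) for r in range(world_size) for i in range(files_per_rank)}
    return sorted(expected - present)


def find_missing_in_dir(path: str, world_size: int, files_per_rank: int) -> List[Tuple[int, int]]:
    """Convenience wrapper that lists a checkpoint directory on disk."""
    return find_missing_shards(os.listdir(path), world_size, files_per_rank)
```

For example, `find_missing_in_dir(os.path.join(checkpoint_path, "model"), world_size=64, files_per_rank=2)` should list the five absent shards, which makes it easy to check whether the missing ranks cluster on particular nodes.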

Notably, this does not happen on smaller setups: on 1 or 2 nodes the save completes, but on 8 nodes the hang has occurred every time I've tested it.

To Reproduce

Here is how we call the checkpointing code:

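import os
from typing import Any, Dict

import torch
from megatron.core import dist_checkpointing, parallel_state
from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelSaveStrategyWrapper

# self._model is a megatron.core.distributed.DistributedDataParallel instance
model_state_dict: Dict[str, Any] = {}
model_state_dict["model"] = self._model.sharded_state_dict()

# torch_dist backend, with the save work spread across the data-parallel group
save_strategy = get_default_save_sharded_strategy('torch_dist')
save_strategy = FullyParallelSaveStrategyWrapper(
    save_strategy,
    parallel_state.get_data_parallel_group(with_context_parallel=False),
    False,  # do_cache_distribution
)

torch.distributed.barrier()

# async_sharded_save=False: save() finalizes the write before returning
_ = dist_checkpointing.save(
    model_state_dict,
    os.path.join(checkpoint_path, "model"),
    save_strategy,
    async_sharded_save=False,
)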

We observe the same behavior with and without the FullyParallelSaveStrategyWrapper.

Expected behavior

The checkpoint finishes saving.

Stack trace/logs

On larger configurations (16 nodes, 2-way MP × 64-way DDP) we sometimes observe the error below. I don't know whether it is related to this issue or is a separate one. For what it's worth, we don't observe it on the 8-node configuration.

"/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 394, in save
    sharded_strategy.save(sharded_state_dict, checkpoint_dir)
  File "/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/strategies/base.py", line 185, in save
    async_calls.maybe_finalize_async_calls(blocking=True)
  File "/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/strategies/async_utils.py", line 217, in maybe_finalize_async_calls
    finalize_fn()
  File "/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py", line 661, in finalize_fn
    save_state_dict_async_finalize(*save_state_dict_ret)
  File "/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/strategies/state_dict_saver.py", line 144, in save_state_dict_async_finalize
    write_results = storage_writer.retrieve_write_results()
  File "/app/exa/cluster_scheduler/megatron_job_image.binary.runfiles/exafunction/third_party/megatron_lm/Megatron-LM/megatron/core/dist_checkpointing/strategies/filesystem_async.py", line 309, in retrieve_write_results
    raise RuntimeError(f'results_queue should not be empty')
RuntimeError: results_queue should not be empty

Environment

  • Megatron-LM commit ID: baf94af3c667248865f23df73b9fb8e2395e6fd0 + internal patches unrelated to checkpointing
  • PyTorch version 2.2.2
  • CUDA version 12.2.2
  • NCCL version 2.18

Additional context

The volume we're saving to is a CephFS filesystem hosted on another node in our cluster. Checkpointing works with more or less the same setup on Megatron v0.6.0 (before the new interface was introduced, we passed in the ('torch_dist', 1) strategy).

Please let me know what additional info I can provide.

chotzen, Sep 23 '24 22:09

Hi, it might be the case that the worker processes writing the checkpoint files are failing; the 16-node log suggests this error may be silently ignored.
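The failure mode suggested here (a writer dies, and the caller only notices an empty results queue) can be illustrated with a small self-contained sketch. It uses threads and a `queue.Queue` in place of the worker processes and `results_queue` in `filesystem_async.py`, and the helper names `run_worker` and `run_all` are hypothetical, not Megatron-Core APIs. The point is the pattern: each worker reports either a result or its formatted traceback, so the caller can raise the real cause instead of a generic "queue should not be empty" error.

```python
import queue
import threading
import traceback
from typing import Any, Callable, List, Optional, Tuple

Result = Tuple[int, Any, Optional[str]]  # (worker index, value, traceback or None)


def run_worker(idx: int, work: Callable[[], Any], results: "queue.Queue[Result]") -> None:
    try:
        results.put((idx, work(), None))
    except Exception:
        # Report the failure instead of dying silently; a worker that let the
        # exception escape would put nothing, leaving the queue short/empty.
        results.put((idx, None, traceback.format_exc()))


def run_all(works: List[Callable[[], Any]]) -> List[Any]:
    results: "queue.Queue[Result]" = queue.Queue()
    threads = [
        threading.Thread(target=run_worker, args=(i, w, results))
        for i, w in enumerate(works)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    collected: List[Result] = []
    while True:
        try:
            collected.append(results.get_nowait())
        except queue.Empty:
            break

    # Still possible if a worker dies from something not caught above.
    if len(collected) != len(works):
        raise RuntimeError(f"expected {len(works)} results, got {len(collected)}")
    for idx, _, tb in collected:
        if tb is not None:
            raise RuntimeError(f"worker {idx} failed:\n{tb}")
    return [value for _, value, _ in sorted(collected, key=lambda r: r[0])]
```

With this shape, a failing write surfaces as, e.g., `worker 1 failed: ... ZeroDivisionError ...` rather than an uninformative empty-queue error.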

One thing that might help figure out what's going on is enabling debug logs by setting MEGATRON_LOGGING_LEVEL=10. Can you try that? Please make sure the output from different ranks is distinguishable, e.g. by passing the -l flag to srun in the Slurm run.
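For reference, level 10 is Python's `logging.DEBUG`. If labelling output with `srun -l` is not an option, a rank prefix can also be attached in each process with a stdlib `logging.Filter`. This is an illustrative sketch under assumptions: `add_rank_prefix` and `RankFilter` are hypothetical helpers, the rank is read from the `RANK` or `SLURM_PROCID` environment variable (whichever the launcher sets), and it assumes Megatron-Core's loggers sit under the `"megatron"` logger namespace.

```python
import logging
import os
from typing import Optional, TextIO, Union

assert logging.DEBUG == 10  # MEGATRON_LOGGING_LEVEL=10 corresponds to DEBUG


class RankFilter(logging.Filter):
    """Attach the process rank to every record so interleaved output can be split."""

    def __init__(self, rank: str) -> None:
        super().__init__()
        self.rank = rank

    def filter(self, record: logging.LogRecord) -> bool:
        record.rank = self.rank
        return True


def add_rank_prefix(
    logger: Optional[logging.Logger] = None,
    rank: Optional[Union[int, str]] = None,
    stream: Optional[TextIO] = None,  # None means sys.stderr
) -> logging.Handler:
    logger = logger or logging.getLogger("megatron")  # assumed namespace
    if rank is None:
        # Assumption: torchrun sets RANK, Slurm sets SLURM_PROCID.
        rank = os.environ.get("RANK") or os.environ.get("SLURM_PROCID", "?")
    handler = logging.StreamHandler(stream)
    handler.addFilter(RankFilter(str(rank)))
    handler.setFormatter(
        logging.Formatter("[rank %(rank)s] %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    return handler
```

Calling `add_rank_prefix()` once per process, before the save, prefixes every debug line with its rank. If the root logger already has handlers, records will also propagate there, so you may see each line twice.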


we see that only 123 of the expected 128 checkpoint shards were saved to the directory

Just double-checking: if you don't use FullyParallelSaveStrategyWrapper, the directory is expected to contain 4 shards (= 2 × 2 = MP × multiprocessing factor). Is that the case?
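The arithmetic behind the 4-shard figure, and behind the 128 expected shards reported for 8 nodes, fits in a tiny sketch. It assumes that without FullyParallelSaveStrategyWrapper only one data-parallel replica writes, while with the wrapper the write is spread over all MP × DP ranks, and that each writing rank produces `files_per_rank` (the multiprocessing factor, 2 here) files. `expected_shard_count` is a hypothetical helper.

```python
def expected_shard_count(mp: int, dp: int, files_per_rank: int, fully_parallel: bool) -> int:
    """Expected number of shard files under the assumptions stated above."""
    # Without the wrapper: only the MP ranks of a single DP replica write.
    # With the wrapper: every one of the mp * dp ranks writes its share.
    writers = mp * dp if fully_parallel else mp
    return writers * files_per_rank
```

For the 8-node setup (2-way MP × 32-way DDP, 2 files per rank) this gives 4 shards without the wrapper and 128 with it, matching both numbers in this thread.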

mikolajblaz, Sep 25 '24 11:09

Marking as stale. No activity in 60 days.

github-actions[bot], Nov 24 '24 18:11

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot], Jul 31 '25 02:07