Megatron-LM
[BUG] Error when using PyTorch 2.4.0 to run run_simple_mcore_train_loop.py
Describe the bug
Running examples/run_simple_mcore_train_loop.py with PyTorch 2.4.0 fails while saving the distributed checkpoint: dist_checkpointing.save raises a CheckpointException caused by an AssertionError in create_default_global_save_plan.
To Reproduce
PYTHONPATH=$PYTHON_PATH:./megatron torchrun --nproc-per-node 2 examples/run_simple_mcore_train_loop.py
Stack trace/logs
[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/megatron/examples/run_simple_mcore_train_loop.py", line 152, in <module>
[rank1]: save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path)
[rank1]: File "/workspace/megatron/examples/run_simple_mcore_train_loop.py", line 109, in save_distributed_checkpoint
[rank1]: dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/serialization.py", line 393, in save
[rank1]: sharded_strategy.save(sharded_state_dict, checkpoint_dir)
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/strategies/base.py", line 180, in save
[rank1]: async_request = self.async_save(sharded_state_dict, checkpoint_dir)
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/strategies/torch.py", line 632, in async_save
[rank1]: ) = save_state_dict_async_plan(
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/strategies/state_dict_saver.py", line 108, in save_state_dict_async_plan
[rank1]: central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/utils.py", line 191, in reduce_scatter
[rank1]: raise result
[rank1]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
[rank1]: Traceback (most recent call last): (RANK 0)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/utils.py", line 179, in reduce_scatter
[rank1]: reduce_fun(cast(List[T], all_data)),
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/strategies/state_dict_saver.py", line 97, in global_step
[rank1]: all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
[rank1]: File "/workspace/megatron/megatron/core/dist_checkpointing/strategies/torch.py", line 465, in create_global_plan
[rank1]: global_plan, metadata = super().create_global_plan(all_plans)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 109, in create_global_plan
[rank1]: global_plan, metadata = create_default_global_save_plan(all_plans)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 388, in create_default_global_save_plan
[rank1]: assert item.index.fqn not in md
[rank1]: AssertionError
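For context, the assertion that fails enforces that an item's fully qualified name (FQN) appears only once in the merged global metadata when per-rank save plans are combined. The sketch below is NOT the actual PyTorch implementation, just a simplified plain-Python model of that invariant; the function and variable names here are illustrative assumptions:

```python
def merge_save_plans(all_plans):
    """Illustrative model of merging per-rank save plans.

    Each plan is a list of FQN strings a rank wants to write. Mirrors the
    spirit of `assert item.index.fqn not in md` in
    torch/distributed/checkpoint/default_planner.py: a duplicate FQN
    across plans trips an AssertionError, as seen in the trace above.
    """
    md = {}
    for rank, plan in enumerate(all_plans):
        for fqn in plan:
            # Two ranks claiming the same FQN reproduces the failure mode.
            assert fqn not in md, f"duplicate FQN {fqn!r} from rank {rank}"
            md[fqn] = rank
    return md


if __name__ == "__main__":
    # Disjoint plans merge cleanly.
    merge_save_plans([["embedding.weight"], ["decoder.bias"]])
    # Overlapping plans raise, analogous to the AssertionError above.
    try:
        merge_save_plans([["embedding.weight"], ["embedding.weight"]])
    except AssertionError as e:
        print("AssertionError:", e)
```

This only models the uniqueness check; the real planner also merges sharded-tensor chunks, which is where the PyTorch 2.4.0 behavior change interacts with Megatron-LM's custom planner.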
Environment (please complete the following information):
- Megatron-LM commit ID: 6bf8448
- Container image: nvcr.io/nvidia/pytorch:24.07-py3
Proposed fix
PR https://github.com/NVIDIA/Megatron-LM/pull/1004
@1195343015 @michal2409 is adding this fix to the library. Thanks.