
[BUG] Cannot load _extra_state with TorchDistLoadShardedStrategy

Open · ZetangForward opened this issue 1 year ago · 1 comment

Describe the bug: I used the TorchDistLoadShardedStrategy loading strategy to load model weights in the distcp format. Two different entry types are involved: ShardedTensor and ShardedObject. The former stores the sliced weights, while the latter contains _extra_state, which holds FP8-related metadata for the weights. The issue is that when TorchDistLoadShardedStrategy attempts to read a ShardedObject, it throws an error.

[screenshot of the error omitted]

I notice the code here:

checkpoint.load(
    pyt_state_dict,
    FileSystemReader(checkpoint_dir),
    planner=MCoreLoadPlanner(
        shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
    ),
)

Megatron reads the model weights in binary form, loading the results from checkpoint_dir and overwriting them into sharded_state_dict. However, it seems that this path can only materialize entries of the ShardedTensor type and is unable to deserialize the contents stored as ShardedObject.

As a result, the extra_state_dict is loaded as raw io.BytesIO objects instead of the expected structured data.
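To illustrate the symptom, here is a minimal, self-contained sketch. It is a hypothetical stand-in that uses the stdlib pickle module rather than Megatron's actual torch.save/torch.load serialization, and the fp8_meta payload shown is invented for illustration: an _extra_state value is packed into a byte buffer on save, so on load the consumer sees only an opaque io.BytesIO object unless an explicit deserialization step runs.

```python
import io
import pickle

# Hypothetical stand-in for an FP8 _extra_state payload; the real payload
# is produced by the FP8 layers and serialized with torch.save, not pickle.
extra_state = {"fp8_meta": {"amax_history_len": 1024, "recipe": "delayed"}}

# On save, the state is packed into a byte buffer (the ShardedObject content).
buf = io.BytesIO()
pickle.dump(extra_state, buf)

# On load, the strategy hands back the raw buffer. At this point the caller
# only sees an opaque io.BytesIO, which matches the behavior in the report:
buf.seek(0)
assert isinstance(buf, io.BytesIO)

# The structured data is only recovered after an explicit deserialization
# step -- the step that appears to be missing for ShardedObject entries.
restored = pickle.load(buf)
assert restored == extra_state
```

The sketch only demonstrates why a missing deserialization step surfaces as io.BytesIO values in the loaded state dict; it is not the fix itself.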

ZetangForward · Mar 23 '25 05:03

Marking as stale. No activity in 60 days.

github-actions[bot] · May 22 '25 18:05

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] · Jul 28 '25 02:07