[BUG] Cannot load _extra_state with TorchDistLoadShardedStrategy
Describe the bug
I used the TorchDistLoadShardedStrategy loading strategy to load model weights from a checkpoint in the distcp format. Two different sharded types are involved: ShardedTensor and ShardedObject. The former stores sliced weight tensors, while the latter holds _extra_state, which contains FP8-related metadata about the weights. The issue is that when TorchDistLoadShardedStrategy attempts to read a ShardedObject, it throws an error.
I notice the code here:
checkpoint.load(
    pyt_state_dict,
    FileSystemReader(checkpoint_dir),
    planner=MCoreLoadPlanner(
        shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
    ),
)
Megatron loads model weights in binary form: it reads the results from checkpoint_dir and overwrites them into sharded_state_dict. However, it appears to handle only content of the ShardedTensor type and is unable to read the contents stored in a ShardedObject.
As a result, the _extra_state entries end up loaded as raw io.BytesIO objects instead of the expected structured data.
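As a possible post-processing workaround, the io.BytesIO payloads could be deserialized manually after loading. This is a minimal sketch, not the library's own fix; the helper name `fixup_extra_state` is hypothetical, and it assumes the _extra_state buffers were originally written with torch.save:

```python
import io

import torch


def fixup_extra_state(state_dict):
    """Deserialize io.BytesIO values left behind for _extra_state keys.

    Hypothetical helper: assumes each buffer was produced by torch.save,
    which is how _extra_state is commonly serialized.
    """
    for key, value in state_dict.items():
        if key.endswith("_extra_state") and isinstance(value, io.BytesIO):
            value.seek(0)  # rewind before reading the serialized payload
            state_dict[key] = torch.load(value)
    return state_dict
```

This only papers over the symptom on the consumer side; the underlying issue is that the load strategy should deserialize ShardedObject entries itself.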