🐛 Describe the bug
I tried resuming training from a previous unsharded checkpoint saved at step 1k. Training resumed without any initial issues, but when the trainer went to save its first sharded checkpoint I hit the error shown below. Any idea what is causing this? For context, the environment and the number of nodes are the same as in the original run.
```
Traceback (most recent call last):
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 345, in <module>
    main(cfg)
  File "/mnt/azureml/cr/j/947c8b089dfe4d0484df42634f296716/exe/wd/scripts/train.py", line 316, in main
    trainer.fit()
  File "/workspace/OLMo/olmo/train.py", line 1153, in fit
    checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 560, in save_checkpoint
    result = self.save_sharded_checkpoint()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 468, in save_sharded_checkpoint
    result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 428, in _save_checkpoint
    checkpointer.save_checkpoint(
  File "/workspace/OLMo/olmo/checkpoint.py", line 1000, in save_checkpoint
    "optim": FSDP.optim_state_dict(dist_model, optim),
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
    return _optim_state_dict(
           ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
    fsdp_osd_state = convert_fn(
                     ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1415, in _convert_all_state_info
    assert curr_scalar_tensor_value is None or torch.equal(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Rank 4 has different values for step: 1500.0. Other ranks: 500.0
```
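
If it helps, here is a small diagnostic I can drop in right before the sharded save. This is my own sketch, not OLMo code: it assumes `torch.distributed` is already initialized and that the optimizer keeps a scalar `"step"` entry per parameter (as AdamW does), and just prints each rank's view of those step values.

```python
# Diagnostic sketch (not OLMo code): report each rank's optimizer "step" values
# right before FSDP.optim_state_dict is called, to see which ranks disagree.
import torch
import torch.distributed as dist


def report_optimizer_steps(optim: torch.optim.Optimizer) -> None:
    # Collect the distinct "step" values held in this rank's optimizer state.
    local_steps = sorted(
        {float(state["step"]) for state in optim.state.values() if "step" in state}
    )
    # Gather every rank's view so rank 0 can print a summary.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_steps)
    if dist.get_rank() == 0:
        for rank, steps in enumerate(gathered):
            print(f"rank {rank}: step values {steps}")
```

Based on the assertion message I'd expect this to show rank 4 holding 1500.0 while the other ranks hold 500.0, which should also make it clear whether only rank 4 disagrees or the mismatch is wider than the message suggests.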
Versions
.