FSDP unable to load checkpoint, state dict, saved weights
System Info
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: ''
main_process_port: ''
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [X] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
Hi all,
When trying to resume from a model checkpoint with `use_orig_params=True`:
```python
# Note: the guard needs `and`; with `or` the condition is always true.
if CFG.RESUME_FROM_CHECKPOINT is not None and CFG.RESUME_FROM_CHECKPOINT != "":
    accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
```
I am receiving this error:
```
    main()
  File "/train_distributed.py", line 545, in main
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
  File "/lib/python3.9/site-packages/accelerate/accelerator.py", line 2439, in load_state
    self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i)
  File "/lib/python3.9/site-packages/accelerate/utils/dataclasses.py", line 951, in load_optimizer
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1581, in scatter_full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1236, in _optim_state_dict_to_load_impl
    flat_osd = _flatten_optim_state_dict(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 359, in _flatten_optim_state_dict
    assert (
AssertionError: If use_orig_params is True, shard_state must be True.
```
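For context, the assertion fires because with `use_orig_params=True` PyTorch expects sharded optimizer state rather than a full optimizer state dict scattered from rank 0. Below is a minimal sketch of opting into sharded state dicts at the raw FSDP level, assuming PyTorch 2.x (the config classes come from `torch.distributed.fsdp`; the helper name is illustrative):

```python
# Sketch (assumes PyTorch 2.x): switch an FSDP-wrapped model to sharded
# state dicts so that save/load works with use_orig_params=True.
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)


def enable_sharded_state_dicts(model: FSDP) -> None:
    # Each rank then saves/loads only its own shard instead of
    # materializing the full state dict on rank 0.
    FSDP.set_state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    )
```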
I opened an issue with the PyTorch team, who pointed out that it seems the "save method in HF accelerate for FSDP sharded state dict needs some updates. The sharded state dict save/load can be found here. This also extends to saving the optimizer states. The sharded optimizer save/load can be found here."
Thank you for the help!
Enrico
Expected behavior
Checkpoints, state dicts, and saved weights load correctly when using FSDP.
cc @pacman100
Hello @conceptofmind, the above PR should fix this. Choose `state_dict_type` as `SHARDED_STATE_DICT` when answering the questionnaire after running the `accelerate config` command.
Please let us know if this resolves the issue.
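Concretely, choosing that option should leave the generated config file with an `fsdp_config` section along these lines (a sketch; exact keys can vary between Accelerate versions):

```yaml
fsdp_config:
  fsdp_state_dict_type: SHARDED_STATE_DICT  # each rank saves/loads its own shard
  fsdp_use_orig_params: true
```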
Hi @pacman100 ,
Thank you for the response. I will test out loading and saving the models with FSDP.
Best,
Enrico
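For anyone verifying the same flow, here is a minimal save/resume round trip using the public `Accelerator` API (the checkpoint directory and training objects are illustrative):

```python
# Minimal sketch of a save/resume round trip with Accelerate + FSDP.
# `model`, `optimizer`, and `dataloader` stand in for your own objects.
from accelerate import Accelerator


def checkpoint_round_trip(model, optimizer, dataloader):
    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # Saves model, optimizer, and RNG state; with SHARDED_STATE_DICT
    # each rank writes only its own shard of the weights/optimizer.
    accelerator.save_state("checkpoints/step_1000")

    # Restores everything saved above from the same directory.
    accelerator.load_state("checkpoints/step_1000")
    return model, optimizer
```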
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.