
FSDP unable to load checkpoint, state dict, saved weights

Open conceptofmind opened this issue 2 years ago

System Info

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: ''
main_process_port: ''
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Information

  • [X] The official example scripts
  • [X] My own modified scripts

Tasks

  • [X] One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • [X] My own task or dataset (give details below)

Reproduction

Hi all,

When trying to resume from a model checkpoint with use_orig_params=True:

        if CFG.RESUME_FROM_CHECKPOINT is not None and CFG.RESUME_FROM_CHECKPOINT != "":
            accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
            accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)

I am receiving this error:

    main()
  File "/train_distributed.py", line 545, in main
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
  File "/lib/python3.9/site-packages/accelerate/accelerator.py", line 2439, in load_state
    self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i)
  File "/lib/python3.9/site-packages/accelerate/utils/dataclasses.py", line 951, in load_optimizer
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1581, in scatter_full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1236, in _optim_state_dict_to_load_impl
    flat_osd = _flatten_optim_state_dict(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 359, in _flatten_optim_state_dict
    assert (
AssertionError: If use_orig_params is True, shard_state must be True.

I opened an issue with the PyTorch team, who pointed out that the "save method in HF accelerate for FSDP sharded state dict needs some updates. The sharded state dict save/load can be found here. This also extends to saving the optimizer states. The sharded optimizer save/load can be found here."
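
For context, my understanding is that the PyTorch-side sharded state dict API they are referring to looks roughly like the sketch below (model and optimizer are placeholders for an FSDP-wrapped module and its optimizer, not code from my script):

    # Illustrative sketch, not Accelerate's actual implementation: saving a
    # sharded (per-rank) model/optimizer state dict via torch.distributed.fsdp.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import (
        StateDictType,
        ShardedStateDictConfig,
        ShardedOptimStateDictConfig,
    )

    with FSDP.state_dict_type(
        model,  # placeholder: an FSDP-wrapped module
        StateDictType.SHARDED_STATE_DICT,
        ShardedStateDictConfig(offload_to_cpu=True),
        ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_state = model.state_dict()                       # per-rank shards
        optim_state = FSDP.optim_state_dict(model, optimizer)  # sharded optimizer state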

Thank you for the help!

Enrico

Expected behavior

Checkpoints, state dicts, and saved weights should load correctly when using FSDP.
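
For example, a minimal round trip like the one below should work (the tiny model, optimizer, and checkpoint path are illustrative placeholders, not my actual training setup):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    # Save model weights, optimizer state, and RNG states, then restore them.
    accelerator.save_state("checkpoints/step_0")
    accelerator.load_state("checkpoints/step_0")  # should also succeed under FSDP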

conceptofmind avatar May 22 '23 02:05 conceptofmind

cc @pacman100

sgugger avatar May 22 '23 13:05 sgugger

Hello @conceptofmind, the above PR should fix this. Choose state_dict_type as SHARDED_STATE_DICT when answering the questionnaire after running the accelerate config command.
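
With that choice, the fsdp_config section of the generated YAML should look roughly like this (only the state dict key matters here; fsdp_use_orig_params is shown as an illustrative companion setting):

    fsdp_config:
      fsdp_state_dict_type: SHARDED_STATE_DICT
      fsdp_use_orig_params: true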

Please let us know if this resolves the issue.

pacman100 avatar Jun 13 '23 19:06 pacman100

Hi @pacman100 ,

Thank you for the response. I will test out loading and saving the models with FSDP.

Best,

Enrico

conceptofmind avatar Jun 13 '23 19:06 conceptofmind

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jul 08 '23 15:07 github-actions[bot]