
PyTorch FSDP with Accelerate

Open conceptofmind opened this issue 2 years ago • 3 comments

Hi all,

I was wondering whether you could give any input on whether the standard PyTorch FSDP wrapper is compatible with Hugging Face Accelerate's accelerator.prepare().

For example:

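import torch
from accelerate import Accelerator
from torch.distributed.fsdp import FullyShardedDataParallel

# Accelerator takes mixed precision as a string; there is no bf16=True argument
accelerator = Accelerator(mixed_precision="bf16")

# MyModel is a placeholder for the user's own nn.Module
model = MyModel().to(accelerator.device)
fsdp_model = FullyShardedDataParallel(model)
fsdp_model = accelerator.prepare(fsdp_model)

# MyOpt is a placeholder optimizer class; build it after FSDP wrapping
# so it receives the wrapped model's parameters
optimizer = MyOpt(fsdp_model.parameters())

# train_dataloader and val_dataloader are assumed to be defined elsewhere
optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    optimizer, train_dataloader, val_dataloader
)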

Thank you,

Enrico

conceptofmind, May 10 '23

Yes. If the model is already wrapped in FullyShardedDataParallel, accelerator.prepare will just return it unchanged.

pacman100, May 10 '23


Thank you for the immediate response.

My current Accelerate config is:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config: {}
gpu_ids: all
machine_rank: 0
main_process_ip: ''
main_process_port: ''
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
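As a quick sanity check on the multi-node values above (num_machines: 8, num_processes: 64), here is a small stdlib sketch of how the processes would be laid out. It assumes num_processes is the total across all machines, split evenly, and that global ranks are assigned machine-major (global rank = machine_rank * processes_per_machine + local_rank), as with a static rendezvous; `rank_layout` is a hypothetical helper, not part of Accelerate.

```python
from __future__ import annotations


def rank_layout(num_machines: int, num_processes: int) -> dict[int, range]:
    """Map each machine_rank to the global process ranks it would host.

    Assumption: num_processes is the total across all machines, split evenly,
    and ranks are machine-major: global = machine_rank * per_machine + local.
    """
    if num_processes % num_machines != 0:
        raise ValueError("num_processes must be divisible by num_machines")
    per_machine = num_processes // num_machines
    return {
        machine_rank: range(machine_rank * per_machine, (machine_rank + 1) * per_machine)
        for machine_rank in range(num_machines)
    }


# Values taken from the config above.
layout = rank_layout(num_machines=8, num_processes=64)
print(len(layout[0]))   # 8 processes per machine
print(list(layout[0]))  # machine_rank 0 -> global ranks 0..7
print(layout[7][-1])    # highest global rank: 63
```

Under these assumptions each of the 8 machines runs 8 processes (typically one per GPU), and the machine with machine_rank: 0 hosts global ranks 0 to 7.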

Would this configuration require any alterations to work with PyTorch FullyShardedDataParallel?

Does mixed precision need to be set to bf16 in the FSDP wrapper, or does setting mixed_precision: bf16 in the Accelerate config already handle this? And does anything need to be added to fsdp_config if FSDP is already configured in the code, as follows?

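import functools

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# ParallelTransformerBlock is the model's own transformer block class;
# activation checkpointing is applied to every instance of it.
check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)

non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Wrap each ParallelTransformerBlock in its own FSDP unit.
palm_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        ParallelTransformerBlock,
    },
)

bf16_fsdp = MixedPrecision(
    # Parameter precision for forward/backward computation.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

model = FullyShardedDataParallel(
    model,
    auto_wrap_policy=palm_auto_wrap_policy,
    mixed_precision=bf16_fsdp,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    forward_prefetch=True,
    use_orig_params=True,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

model = accelerator.prepare(model)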
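Both check_fn and the transformer wrap policy in the code above select submodules by an isinstance test against ParallelTransformerBlock. The stdlib-only toy below is a minimal sketch of that selection, so it can run without torch or GPUs. `ToyModule`, `Embedding`, `selected`, and the module names are hypothetical stand-ins (ToyModule loosely mimics `nn.Module.named_modules()`), not FSDP's actual implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Iterator


@dataclass
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: a name plus child modules."""

    name: str
    children: list[ToyModule] = field(default_factory=list)

    def named_modules(self, prefix: str = "") -> Iterator[tuple[str, ToyModule]]:
        # Pre-order walk yielding dotted names, like nn.Module.named_modules().
        yield prefix, self
        for child in self.children:
            child_prefix = f"{prefix}.{child.name}" if prefix else child.name
            yield from child.named_modules(child_prefix)


class ParallelTransformerBlock(ToyModule):
    """Hypothetical stand-in for the model's transformer block class."""


class Embedding(ToyModule):
    """Hypothetical stand-in for a non-transformer submodule."""


def selected(root: ToyModule, predicate: Callable[[ToyModule], bool]) -> list[str]:
    """Return the dotted names of all submodules for which predicate is True."""
    return [name for name, module in root.named_modules() if predicate(module)]


# Same predicate shape as check_fn in the code above.
check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)

model = ToyModule(
    "palm",
    [
        Embedding("token_emb"),
        ToyModule("layers", [ParallelTransformerBlock(str(i)) for i in range(3)]),
    ],
)
print(selected(model, check_fn))  # ['layers.0', 'layers.1', 'layers.2']
```

Only the transformer blocks match, so the embedding and the container module are left alone by the checkpointing predicate.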

Thank you,

Enrico

conceptofmind, May 10 '23

Additionally, I am having an issue when trying to resume from a model checkpoint with use_orig_params=True:

# `and`, not `or`: only resume when a non-empty checkpoint path is set
if CFG.RESUME_FROM_CHECKPOINT is not None and CFG.RESUME_FROM_CHECKPOINT != "":
    accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
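Separately from the FSDP error, a resume guard written as `path is not None or path != ""` is always true: for None the second test (`None != ""`) is true, and for "" the first test is true. That would make load_state run even when no checkpoint is configured. The stdlib sketch below compares the `or` form with the `and` form; the function names and the path "checkpoints/step_1000" are hypothetical.

```python
from typing import Optional


def should_resume_buggy(path: Optional[str]) -> bool:
    # `or` form: True for None (None != "") and for "" ("" is not None).
    return path is not None or path != ""


def should_resume(path: Optional[str]) -> bool:
    # `and` form: True only for a non-empty path string.
    return path is not None and path != ""


for value in (None, "", "checkpoints/step_1000"):
    print(repr(value), should_resume_buggy(value), should_resume(value))
```

For a plain string-or-None setting, `if path:` expresses the same check as the `and` form more compactly.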

I am receiving this error:

    main()
  File "/train_distributed.py", line 545, in main
    accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
  File "/lib/python3.9/site-packages/accelerate/accelerator.py", line 2439, in load_state
    self.state.fsdp_plugin.load_optimizer(self, opt, self._models[i], input_dir, i)
  File "/lib/python3.9/site-packages/accelerate/utils/dataclasses.py", line 951, in load_optimizer
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1581, in scatter_full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1236, in _optim_state_dict_to_load_impl
    flat_osd = _flatten_optim_state_dict(
  File "/lib/python3.9/site-packages/torch/distributed/fsdp/_optim_utils.py", line 359, in _flatten_optim_state_dict
    assert (
AssertionError: If use_orig_params is True, shard_state must be True.

I am unsure where shard_state should be set to True as I do not see it in the FSDP documentation.

Thank you,

Enrico

conceptofmind, May 17 '23

This has not been resolved and the PR addressing it was automatically closed.

conceptofmind, Jun 11 '23

Hello @conceptofmind, Accelerate returns the model as is if it is already an instance of FullyShardedDataParallel. Refer to this line: https://github.com/huggingface/accelerate/blob/50eabe5b1d3cde24a9f3b65b6e7e25075b269da4/src/accelerate/accelerator.py#L1302
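The behaviour described here is an early return on an isinstance check. The sketch below is a simplified, hypothetical model of that idea using a stub class instead of torch, so it runs anywhere. It is not Accelerate's actual prepare code; `prepare_model` and the stub `FullyShardedDataParallel` class are stand-ins.

```python
class FullyShardedDataParallel:
    """Hypothetical stub for torch.distributed.fsdp.FullyShardedDataParallel."""

    def __init__(self, module: object) -> None:
        self.module = module


def prepare_model(model: object) -> object:
    """Simplified sketch of the early return described above."""
    if isinstance(model, FullyShardedDataParallel):
        # Already wrapped: return the same object instead of wrapping it again.
        return model
    # Otherwise wrap once (Accelerate would presumably use its FSDP plugin settings).
    return FullyShardedDataParallel(model)


user_model = object()  # placeholder for the user's nn.Module
wrapped = FullyShardedDataParallel(user_model)
assert prepare_model(wrapped) is wrapped                 # same object back
assert prepare_model(user_model).module is user_model    # wrapped exactly once
```

One consequence of this design: any FSDP options set by hand in the code (mixed precision, wrap policy, use_orig_params) are the ones that take effect for a pre-wrapped model, because no second wrap is applied.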

Regarding use_orig_params, I have replied on the relevant issues.

pacman100, Jun 13 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jul 08 '23