
Unable to specify HYBRID_SHARD for FSDP which requires process group or device_mesh to be passed

Open npuichigo opened this issue 1 year ago • 8 comments

System Info

- `Accelerate` version: 0.29.1
- Platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/yuchao/miniconda3/envs/TorchTTS/bin/accelerate
- Python version: 3.10.13
- Numpy version: 1.23.5
- PyTorch version (GPU?): 2.2.2+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 125.48 GB
- GPU type: NVIDIA GeForce RTX 4090
- `Accelerate` default config:
fsdp_plugin:
  sharding_strategy: 4

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • [X] My own task or dataset (give details below)

Reproduction

Set the FSDP sharding strategy to ShardingStrategy.HYBRID_SHARD or ShardingStrategy._HYBRID_SHARD_ZERO2.
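For illustration (the actual script is a custom task, so this is only a minimal sketch along the same lines, launched with accelerate launch across multiple processes):

import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# HYBRID_SHARD: shard parameters within a node, replicate across nodes
fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.HYBRID_SHARD)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = accelerator.prepare(nn.Linear(16, 16))  # fails with the ValueError in the traceback below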

File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/accelerate/accelerator.py", line 1434, in prepare_model
    model = FSDP(model, **kwargs)
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 448, in __init__
    _init_process_group_state(
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 118, in _init_process_group_state
    raise ValueError(
ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')

Expected behavior

Accelerate should provide a way to pass the process_group or device_mesh, since they are parameters of FSDP: https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel
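For reference, a minimal sketch of what raw FSDP expects when HYBRID_SHARD is used with manual wrapping (the mesh dimensions here are placeholder values for 2 nodes with 4 GPUs each, and the script must be started with a distributed launcher such as torchrun):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# 2D mesh: dim 0 replicates across nodes, dim 1 shards within a node
mesh = init_device_mesh("cuda", (2, 4))

model = FSDP(
    nn.Linear(16, 16).cuda(),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
)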

npuichigo avatar Apr 15 '24 10:04 npuichigo

I think the issue here is that you're not setting an auto_wrap_policy, which is required? cc @pacman100 to confirm

muellerzr avatar Apr 15 '24 13:04 muellerzr

auto_wrap_policy may be another way, but I can find no way to set process_group or device_mesh, which are part of the FSDP interface:

torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)

npuichigo avatar Apr 15 '24 14:04 npuichigo

True, we need to enable this in the FullyShardedDataParallelPlugin and make it work nicely when things are set via accelerate launch (in which case this would not be a param that can be set there, for obvious reasons).

muellerzr avatar Apr 15 '24 14:04 muellerzr

Sounds like one implication here is that the transformers Trainer's --fsdp hybrid_shard / --fsdp hybrid_shard_zero2 can't possibly run successfully, since process_group is not among the Trainer arguments.

lidingsnyk avatar May 13 '24 13:05 lidingsnyk

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 07 '24 15:06 github-actions[bot]

any updates on this?

ParthaEth avatar Aug 14 '24 08:08 ParthaEth

any updates?

ShengYun-Peng avatar Nov 17 '24 01:11 ShengYun-Peng

If you need to use it in a hurry, here is a quick patch:

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
--- src/accelerate/accelerator.py
+++ src/accelerate/accelerator.py
@@ -1648,9 +1648,13 @@
                         "ignored_modules": fsdp_plugin.ignored_modules,
                         "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                         "device_id": self.device,
                     }
-                    model = FSDP(model, **kwargs)
+                    if 'HYBRID_SHARD' in kwargs['sharding_strategy'].name:
+                        from torch.distributed.device_mesh import init_device_mesh
+                        model = FSDP(model, **kwargs, device_mesh=init_device_mesh(self.device.type, (int(os.environ['HYBRID_SHARD_CROSS_HOST_DIM']), int(os.environ['HYBRID_SHARD_WITHIN_HOST_DIM']))))
+                    else:
+                        model = FSDP(model, **kwargs)
                     if fsdp_plugin.activation_checkpointing:
                         from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                             CheckpointImpl,
                             apply_activation_checkpointing,

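Note that HYBRID_SHARD_CROSS_HOST_DIM and HYBRID_SHARD_WITHIN_HOST_DIM are environment variable names introduced by this patch, not Accelerate options: as the patch is written, the first is the replication dimension across hosts, the second is the sharding dimension within a host, and their product should equal the total number of processes (e.g. 2 and 4 for 2 nodes with 4 GPUs each).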
loopback-kr avatar Apr 28 '25 16:04 loopback-kr

You can use FSDP2 with ParallelismConfig now, such as:

fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2,...)
parallelism_config = ParallelismConfig(dp_shard_size=4, dp_replicate_size=2)

FSDP1 is not actively developed on our side anymore, so I recommend switching to FSDP2.
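A slightly fuller sketch of how these might be wired together (the import path for ParallelismConfig and the parallelism_config argument to Accelerator are assumptions here and may differ between Accelerate versions, so check the docs for your release):

from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig  # import path assumed; may vary by version
from accelerate.utils import FullyShardedDataParallelPlugin

# dp_shard_size * dp_replicate_size should equal the total number of processes,
# e.g. 4 * 2 = 8 GPUs: shard across 4 ranks, replicate across 2 groups
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
parallelism_config = ParallelismConfig(dp_shard_size=4, dp_replicate_size=2)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin, parallelism_config=parallelism_config)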

S1ro1 avatar Sep 16 '25 14:09 S1ro1

This is really helpful!

ZhiliangWu avatar Nov 12 '25 13:11 ZhiliangWu