Unable to specify HYBRID_SHARD for FSDP which requires process group or device_mesh to be passed
System Info
- `Accelerate` version: 0.29.1
- Platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/yuchao/miniconda3/envs/TorchTTS/bin/accelerate
- Python version: 3.10.13
- Numpy version: 1.23.5
- PyTorch version (GPU?): 2.2.2+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 125.48 GB
- GPU type: NVIDIA GeForce RTX 4090
- `Accelerate` default config:
  fsdp_plugin:
    sharding_strategy: 4
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
Specify the FSDP sharding strategy as `ShardingStrategy.HYBRID_SHARD` or `ShardingStrategy._HYBRID_SHARD_ZERO2`, then prepare a model with the `Accelerator`.
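A minimal sketch of the setup (the model here is just a placeholder; run under `accelerate launch` across multiple GPUs/nodes):

```python
import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.HYBRID_SHARD)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = nn.Linear(16, 16)           # placeholder for the real model
model = accelerator.prepare(model)  # fails here
```

This fails inside `accelerator.prepare` with: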
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/accelerate/accelerator.py", line 1434, in prepare_model
return self.prepare_model(obj, device_placement=device_placement)
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/accelerate/accelerator.py", line 1434, in prepare_model
model = FSDP(model, **kwargs)
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 448, in __init__
model = FSDP(model, **kwargs)
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 448, in __init__
_init_process_group_state(
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 118, in _init_process_group_state
_init_process_group_state(
File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 118, in _init_process_group_state
raise ValueError(
ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')
raise ValueError(
ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')
Expected behavior
Accelerate should provide a way to pass the process group or device_mesh, since both are parameters of FSDP: https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel
I think the issue here is you're not setting an auto_wrap_policy, which is required? cc @pacman100 to confirm
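For reference, that suggestion would look roughly like this (a sketch; the `min_num_params` value is arbitrary, and whether this is enough for HYBRID_SHARD is exactly the open question):

```python
import functools
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from accelerate import FullyShardedDataParallelPlugin

# Sketch: passing an auto_wrap_policy through the plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
)
```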
Setting an auto_wrap_policy may be another way, but I can find no way to set process_group or device_mesh, both of which are in the FSDP interface:
```python
torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)
```
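For comparison, this is what wrapping directly with PyTorch FSDP allows but the plugin does not expose (a sketch; it assumes 2 hosts with 4 GPUs each, an already-initialized process group, and a placeholder module):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# 2-D mesh: replicate across 2 hosts, shard across 4 GPUs within each host (assumed topology)
mesh = init_device_mesh("cuda", (2, 4))
model = FSDP(
    nn.Linear(16, 16),  # placeholder module
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
)
```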
True, we need to enable this in the `FullyShardedDataParallelPlugin` and make it work nicely when things are set via `accelerate launch` (where this would not be a parameter that can be set, for obvious reasons).
Sounds like one implication here is that the transformers Trainer's `--fsdp hybrid_shard` / `--fsdp hybrid_shard_zero2` can't possibly run successfully, since process_group is not among the Trainer arguments.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
any updates on this?
any updates?
If you need to use it in a hurry, here is a workaround patch:
```diff
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1648,9 +1648,13 @@
                     "ignored_modules": fsdp_plugin.ignored_modules,
                     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                     "device_id": self.device,
                 }
-                model = FSDP(model, **kwargs)
+                if 'HYBRID_SHARD' in kwargs['sharding_strategy'].name:
+                    from torch.distributed.device_mesh import init_device_mesh
+                    model = FSDP(model, **kwargs, device_mesh=init_device_mesh(self.device.type, (int(os.environ['HYBRID_SHARD_CROSS_HOST_DIM']), int(os.environ['HYBRID_SHARD_WITHIN_HOST_DIM']))))
+                else:
+                    model = FSDP(model, **kwargs)
                 if fsdp_plugin.activation_checkpointing:
                     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                         CheckpointImpl,
                         apply_activation_checkpointing,
```
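With that patch applied, usage would look something like this (a sketch; the two environment variables are defined by the patch itself, not by Accelerate, and the 2x4 topology is just an example):

```python
import os
import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator, FullyShardedDataParallelPlugin

os.environ["HYBRID_SHARD_CROSS_HOST_DIM"] = "2"   # replicate across 2 hosts (example)
os.environ["HYBRID_SHARD_WITHIN_HOST_DIM"] = "4"  # shard across 4 GPUs per host (example)

accelerator = Accelerator(
    fsdp_plugin=FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.HYBRID_SHARD)
)
model = accelerator.prepare(nn.Linear(16, 16))    # placeholder model, wrapped with a 2x4 device mesh
```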
You can use FSDP2 with `ParallelismConfig` now, such as:

```python
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2, ...)
parallelism_config = ParallelismConfig(dp_shard_size=4, dp_replicate_size=2)
```

FSDP1 is not actively developed anymore on our side; I recommend switching to FSDP2.
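A slightly fuller sketch of that setup (assuming a recent Accelerate release where `ParallelismConfig` is importable, e.g. from `accelerate.parallelism_config`, and `Accelerator` accepts a `parallelism_config` argument; adjust to your version):

```python
import torch.nn as nn
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.parallelism_config import ParallelismConfig  # module path may differ per version

# 8 GPUs total: 2-way replication x 4-way sharding (example topology)
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
parallelism_config = ParallelismConfig(dp_shard_size=4, dp_replicate_size=2)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin, parallelism_config=parallelism_config)
model = accelerator.prepare(nn.Linear(16, 16))  # placeholder model
```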
This is really helpful!