FSDP training raises "KeyError: 'ShardingStrategy.NO_SHARD'"
System Info
torch 2.0.1, torchaudio 2.0.2, torchvision 0.15.2
Information
- [ ] The official example scripts
- [ ] My own modified scripts
🐛 Describe the bug
Hi, I can train the asr_librispeech finetuning code with DDP; however, when I switch to FSDP, the exception below is raised.
Error logs
Traceback (most recent call last): File "examples/asr_librispeech/finetune_asr.py", line 41, in main_hydra train(kwargs) File "/SLAM-LLM/src/slam_llm/pipeline/finetune.py", line 167, in main model = FSDP( File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 391, in init _auto_wrap(auto_wrap_kwargs, fsdp_kwargs, FullyShardedDataParallel) File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 73, in _auto_wrap _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/wrap.py", line 370, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/wrap.py", line 370, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/wrap.py", line 370, in _recursive_wrap wrapped_child, num_wrapped_params = _recursive_wrap( [Previous line repeated 3 more times] File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/wrap.py", line 388, in _recursive_wrap return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/wrap.py", line 317, in _wrap return wrapper_cls(module, **kwargs) File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 408, in init _init_param_handle_from_module( File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/_init_utils.py", line 429, in _init_param_handle_from_module _init_param_handle_from_params(state, managed_params, fully_sharded_module) File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/_init_utils.py", line 529, in _init_param_handle_from_params SHARDING_STRATEGY_MAP[state.sharding_strategy], KeyError: 'ShardingStrategy.NO_SHARD'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2031656) of binary: /usr/bin/python3
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 8, in
Expected behavior
How should I modify the code to use FSDP for speedup? Thanks a lot! :D
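In case it helps, here is the workaround I would try: convert the config string back into the `ShardingStrategy` enum before building the FSDP wrapper. This is only a sketch under my assumptions; the helper name and the `fsdp_config.sharding_strategy` field are my guesses, not SLAM-LLM's actual API.

```python
from torch.distributed.fsdp import ShardingStrategy

def resolve_sharding_strategy(value):
    """Map a config value such as 'NO_SHARD' or 'ShardingStrategy.NO_SHARD'
    (or an actual enum member) to the ShardingStrategy enum."""
    if isinstance(value, ShardingStrategy):
        return value
    name = str(value).split(".")[-1]  # tolerate the 'ShardingStrategy.' prefix
    return ShardingStrategy[name]     # enum name lookup; raises KeyError early on typos

# Hypothetical usage in src/slam_llm/pipeline/finetune.py (field name is my guess):
# model = FSDP(
#     model,
#     sharding_strategy=resolve_sharding_strategy(fsdp_config.sharding_strategy),
#     ...
# )
```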