[Feature Request] Support FP8 mixed precision with FSDP Plugin
System Info
accelerate == 0.22 or 0.23.dev (built from main)
transformers == 4.33 or 4.34.dev (built from main)
transformer-engine == 0.11.0
torch == 2.1.1
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
I'm trying to launch multi-node, multi-GPU continued pretraining of Llama-2. My training script uses Accelerate to set up the distributed environment and the HF Transformers Trainer to execute the training loop. I'd like to use FP8 precision with the FSDP plugin, but I'm running into issues.
Below are some details on how to reproduce the issue. In my example I omitted some custom code that distributes the tasks and prepares the data, to keep it simple. Let me know if any key details are missing.
I start the training script with the following command line, which runs on each machine in the multi-node environment:
torch.distributed.run -m accelerate.commands.launch \
  --main_process_ip=$(MASTER_ADDR) --main_process_port=2940 \
  --mixed_precision=<fp8|bf16> --rdzv_backend=c10d \
  --machine_rank=$(RANK) --num_machines=$(WORLD_SIZE) --num_processes=4 \
  --use_fsdp --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
  --fsdp_backward_prefetch_policy=BACKWARD_PRE --fsdp_offload_params=false \
  --fsdp_sharding_strategy=1 --fsdp_state_dict_type=SHARDED_STATE_DICT \
  --fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer \
  --module train_module.train <training script args>
where train_module.train is a custom wrapper on top of the HuggingFace Trainer class with minimal changes to it; a hypothetical sketch is below.
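For context, this is roughly the shape of such a wrapper. It is a hypothetical stand-in for the omitted client code: the model name, dataset wiring, and output directory are placeholders, and the precision/FSDP settings come entirely from the launch flags above.

```python
# Hypothetical sketch of train_module.train; the real client code is omitted.
# Mixed precision and FSDP are configured via the accelerate launch flags
# above, so they are intentionally not repeated in TrainingArguments here.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

def train(train_dataset, model_name="meta-llama/Llama-2-7b-hf", output_dir="out"):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    args = TrainingArguments(output_dir=output_dir, per_device_train_batch_size=1)
    # Trainer.__init__ calls create_accelerator_and_postprocess(), which is
    # where the ValueError below is raised when --mixed_precision=fp8 is set.
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
```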
When running my script with --mixed_precision=bf16, everything works as expected: the model is successfully sharded across GPUs, training starts, and the loss decreases.
However, when passing --mixed_precision=fp8, I get the following error:
<omitting some client-specific code>
  t = Trainer(
  File "/layers/pip/requirements/lib/python3.10/site-packages/transformers/trainer.py", line 347, in __init__
    self.create_accelerator_and_postprocess()
  File "/layers/pip/requirements/lib/python3.10/site-packages/transformers/trainer.py", line 3940, in create_accelerator_and_postprocess
    self.accelerator = Accelerator(
  File "/layers/pip/requirements/lib/python3.10/site-packages/accelerate/accelerator.py", line 365, in __init__
    self.state = AcceleratorState(
  File "pip/requirements/lib/python3.10/site-packages/accelerate/state.py", line 765, in __init__
    fsdp_plugin.set_mixed_precision(self._mixed_precision)
  File "/layers/pip/requirements/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 979, in set_mixed_precision
    raise ValueError(f"Unknown mixed precision value: {mixed_precision}")
ValueError: Unknown mixed precision value: fp8
Looking into the stack trace, I can see that while the Accelerate CLI supports --mixed_precision=fp8 (reference), the FSDP plugin seems to support only the "no", "fp16", and "bf16" options (reference).
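For reference, this is roughly what the failing check looks like. This is a paraphrased sketch of the set_mixed_precision method in accelerate/utils/dataclasses.py; the exact body in 0.23 may differ, but the raise matches the traceback above.

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# Paraphrased sketch of the check that raises in the traceback above. FSDP's
# MixedPrecision policy only accepts torch dtypes, and there is no fp8 branch.
def set_mixed_precision(self, mixed_precision):
    if mixed_precision == "fp16":
        dtype = torch.float16
    elif mixed_precision == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown mixed precision value: {mixed_precision}")
    if self.mixed_precision_policy is None:
        self.mixed_precision_policy = MixedPrecision(
            param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
        )
```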
Can you please confirm that my understanding is correct, i.e. that Accelerate supports FP8 only without ZeRO-3-style sharding frameworks (e.g. FSDP or DeepSpeed)? If so, does the Accelerate team have a timeline for adding FP8 support to the FSDP plugin?
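For what it's worth, FP8 itself does work in these Accelerate versions when FSDP is not in the picture. A minimal sketch, assuming transformer-engine is installed and the GPU supports FP8; the FP8RecipeKwargs fields map onto TransformerEngine's DelayedScaling recipe:

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Minimal FP8 setup without FSDP. Adding --use_fsdp to the launch command
# makes this same configuration fail in set_mixed_precision as shown above.
fp8_kwargs = FP8RecipeKwargs(fp8_format="HYBRID", amax_history_len=16)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_kwargs])
```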
Expected behavior
I expect both bf16 and fp8 to work similarly.
cc @pacman100
FSDP support for fp8 is experimental and is on NVIDIA's roadmap (with no public prototype yet). We need to wait on them.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.