
MS-AMP crashes with DeepSpeed ZeRO 3

rationalism opened this issue 2 years ago · 3 comments

I am fine-tuning Facebook's OPT-1.3B on 2x 4090 GPUs (Ubuntu 22.04, PyTorch 2.1.0, CUDA 12.1) with HuggingFace Accelerate, using this script from the HuggingFace examples repo:

https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py

When using DeepSpeed ZeRO 3 for partitioning model states with optimization level O3, MS-AMP crashes with this stack trace:

Traceback (most recent call last):
  File "/home/alyssa/lm_fun/run_clm.py", line 769, in <module>
    main()
  File "/home/alyssa/lm_fun/run_clm.py", line 583, in main
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/accelerator.py", line 1284, in prepare
    result = self._prepare_deepspeed(*args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/accelerator.py", line 1667, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/deepspeed/__init__.py", line 119, in initialize
    config_class = MSAMPDeepSpeedConfig(config, mpu)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 777, in __init__
    self._do_sanity_check()
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 957, in _do_sanity_check
    self._do_error_check()
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/deepspeed/runtime/config.py", line 40, in _do_error_check
    self.zero_optimization_stage in [ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients],
AssertionError: MS-AMP O3 requires ZeRO with optimizer_states or gradients partitioning.

When I switch to optimization level O2, it crashes with this stack trace instead, presumably because MS-AMP's cast.py does not expect DeepSpeed's parameter partitioning:

Traceback (most recent call last):
  File "/home/alyssa/lm_fun/run_clm.py", line 769, in <module>
    main()
  File "/home/alyssa/lm_fun/run_clm.py", line 583, in main
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/accelerator.py", line 1284, in prepare
    result = self._prepare_deepspeed(*args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/accelerator.py", line 1667, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/deepspeed/__init__.py", line 135, in initialize
    engine = MSAMPDeepSpeedEngine(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 304, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/deepspeed/runtime/engine.py", line 81, in _configure_optimizer
    model, basic_optimizer = msamp_initialize(self.module, basic_optimizer, optlevel)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/__init__.py", line 61, in initialize
    cast_model = LinearReplacer.replace(model)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 176, in replace
    model = cls._replace(model, weight_qtype)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 158, in _replace
    setattr(model, child_name, cls._replace(child, weight_qtype))
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 158, in _replace
    setattr(model, child_name, cls._replace(child, weight_qtype))
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 158, in _replace
    setattr(model, child_name, cls._replace(child, weight_qtype))
  [Previous line repeated 3 more times]
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 154, in _replace
    fp8_net = cls._build_fp8linear(model, weight_qtype)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/nn/linear.py", line 98, in _build_fp8linear
    weight = weight.cast(weight_qtype)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/common/tensor/tensor.py", line 703, in _cast_to_scalingtensor
    return ScalingTensor(TypeCast.cast_to_fp16(self, meta, sync=sync), meta=meta)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/msamp/common/tensor/cast.py", line 81, in cast_to_fp16
    meta.amax[0] = input.abs().max()
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
[2023-11-14 17:42:01,194] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 543272) of binary: /home/alyssa/anaconda3/envs/lm_fun/bin/python3
Traceback (most recent call last):
  File "/home/alyssa/anaconda3/envs/lm_fun/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/commands/launch.py", line 979, in launch_command
    deepspeed_launcher(args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/commands/launch.py", line 695, in deepspeed_launcher
    distrib_run.run(args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
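The RuntimeError at the bottom of the MS-AMP frames can be reproduced in isolation. Under ZeRO 3, parameters that a rank does not own are left as empty placeholders (`numel() == 0`), and `max()` on an empty tensor raises exactly this error. A minimal sketch, assuming only PyTorch:

```python
import torch

# An empty tensor stands in for a parameter that ZeRO 3 has partitioned
# away from this rank (numel() == 0). MS-AMP's cast path computes
# input.abs().max() on it, which raises:
empty_weight = torch.empty(0)
try:
    empty_weight.abs().max()
    err = None
except RuntimeError as e:
    err = str(e)
print(err)  # max(): Expected reduction dim to be specified for input.numel() == 0. ...
```

This matches the failure in `cast_to_fp16` above: the cast code assumes every replaced linear layer still holds its full weight tensor.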

rationalism · Nov 14 '23

Hi @rationalism, thanks for your interest in our work! We have not implemented DeepSpeed ZeRO 3 support in MS-AMP.

ZeRO 1 and ZeRO 2 with MS-AMP support are available.
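For reference, a DeepSpeed config within the supported range would set the ZeRO stage to 1 or 2. A minimal sketch (the `msamp` section follows the MS-AMP documentation of this era; all values are illustrative, not a tested config):

```json
{
  "zero_optimization": {
    "stage": 2
  },
  "msamp": {
    "enabled": true,
    "opt_level": "O2"
  }
}
```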

wkcn · Nov 15 '23

@wkcn Will DeepSpeed ZeRO 3 be supported in the future? I saw that FSDP support is planned.

skyshine102 · Nov 15 '23

Note: DeepSpeed ZeRO 3 works for me if I just use the low-bit optimizer (LBAdamW) in place of Adam, rather than using MS-AMP as an integrated framework.
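The workaround above can be sketched as follows. This skips `msamp.initialize()` (and hence the `LinearReplacer` pass that crashes under ZeRO 3) and uses only MS-AMP's optimizer. The import path `msamp.optim.LBAdamW` and its AdamW-like signature are assumptions based on the MS-AMP docs; the fallback keeps the sketch runnable without MS-AMP installed:

```python
import torch

try:
    # MS-AMP's low-bit AdamW; assumed import path, hedged above.
    from msamp.optim import LBAdamW as OptimizerCls
except ImportError:
    # Fall back to plain AdamW so the sketch runs without MS-AMP.
    from torch.optim import AdamW as OptimizerCls

model = torch.nn.Linear(16, 16)  # stands in for the real OPT-1.3B model
optimizer = OptimizerCls(model.parameters(), lr=2e-5, weight_decay=0.01)

# Hand this optimizer to accelerator.prepare() / deepspeed.initialize() as
# usual; with no msamp.initialize() call, ZeRO 3 partitions its states like
# any other optimizer's and LinearReplacer never runs.
```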

rationalism · Nov 24 '23

Closing this issue since there has been no activity for a long time.

tocean · Aug 13 '24