
[BUG] Zero2 offload overflow

Open · nickyoungforu opened this issue 1 year ago · 21 comments

Describe the bug
I have been able to run my model successfully with ZeRO Stage 3 without any problems. However, when I attempt to run the same model with ZeRO Stage 2, I encounter the following error. The output of nodes 10.223.17.15, 10.223.13.141, 10.78.107.139, 10.67.196.141 and 10.78.121.13 is interleaved in the log; the frames they print belong to this traceback:

    Traceback (most recent call last):
      File "train.py", line 706, in <module>
        main()
      File "train.py", line 587, in main
        accelerator.backward(loss)
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/accelerate/accelerator.py", line 1960, in backward
        self.deepspeed_engine_wrapped.backward(loss, **kwargs)
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/accelerate/utils/deepspeed.py", line 176, in backward
        self.engine.step()
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2169, in step
        self._take_model_step(lr_kwargs)
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2075, in _take_model_step
        self.optimizer.step()
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1842, in step
        scaled_global_grad_norm = self.scaled_global_norm()
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1789, in scaled_global_norm
        return torch.norm(torch.stack(norm_groups), p=norm_type)
      File "/opt/conda/envs/progen/lib/python3.8/site-packages/torch/functional.py", line 1626, in norm
        return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
    RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got Long

To Reproduce
I am using a cluster of 6 machines, each equipped with 40GB A100 GPUs (48 processes across the 6 machines).

Launch command:

    accelerate launch --config_file config/yaml/zero.yaml train.py \
        --with_tracking --report_to tensorboard --output_dir tblog \
        --project_name test --checkpointing_steps epoch \
        --per_device_train_batch_size 8 --num_train_epochs 50 \
        --learning_rate 5e-5 --seed 42 --precision bf16

Here is my YAML:

    compute_environment: LOCAL_MACHINE
    deepspeed_config:
      deepspeed_config_file: ./config/zero/zero_stage2_bf16.json
      deepspeed_multinode_launcher: pdsh
      deepspeed_hostfile: ./config/yaml/hostfile
      zero3_init_flag: true
    distributed_type: DEEPSPEED
    fsdp_config: {}
    machine_rank: 0
    main_process_ip: 10.223.17.15
    main_process_port: 36769
    main_training_function: main
    num_machines: 6
    num_processes: 48
    use_cpu: false

zero_stage2_bf16.json:

    {
      "bf16": {
        "enabled": true
      },
      "optimizer": {
        "type": "AdamW",
        "params": {
          "lr": "auto",
          "weight_decay": "auto",
          "torch_adam": true,
          "adam_w_mode": true
        }
      },
      "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
          "warmup_min_lr": "auto",
          "warmup_max_lr": "auto",
          "warmup_num_steps": "auto",
          "total_num_steps": 533770
        }
      },
      "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
          "device": "cpu",
          "pin_memory": true
        },
        "offload_param": {
          "device": "cpu",
          "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true
      },
      "gradient_accumulation_steps": 1,
      "gradient_clipping": "auto",
      "steps_per_print": 2000,
      "train_batch_size": "auto",
      "train_micro_batch_size_per_gpu": "auto",
      "wall_clock_breakdown": false
    }

nickyoungforu, Mar 08 '24 06:03

Hi, can you please add a title?

loadams, Mar 08 '24 16:03

I've encountered this bug too. After some inspection, it seems to me that the following implementation in the current stable release is related:

# In deepspeed/runtime/zero/stage_1_and_2.py

    def complete_grad_norm_calculation_for_cpu_offload(self, params):
        total_norm = 0.0
        norm_type = 2.0
        for p in params:
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue

            if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                param_id = self.get_param_id(p)
                # as some model have trainable parameters but skipped in training,
                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
                # so they have no norm_for_param_grads
                if param_id in self.norm_for_param_grads:
                    param_norm = self.norm_for_param_grads[param_id]
                    total_norm += param_norm.item()**2
                else:
                    # As unused parameters in modules may not be expected sometimes,
                    # add an explicit error msg when it occurred and an option to
                    # avoid the error
                    assert self.ignore_unused_parameters, """
                        This assert indicates that your module has parameters that
                        were not used in producing loss.
                        You can avoid this assert by
                        (1) enable ignore_unused_parameters option in zero_optimization config;
                        (2) making sure all trainable parameters and `forward` function
                            outputs participate in calculating loss.
                    """

        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm

Here, when the gradient or its norm overflows, the current implementation appears to return the integer -1, which callers are expected to handle as a special case. However, this value is never handled correctly elsewhere, and that causes the reported bug. When this happens, the current training iteration is already irrecoverably broken. So while training should still stop here, I'd suggest raising an exception with the information a user needs to debug (e.g. a hint to check the numerical stability of the algorithm) instead of the current uninformative log message.

desire2020, Mar 11 '24 05:03
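To make the suggestion above concrete, here is a minimal pure-Python sketch (not DeepSpeed code) of the same L2 accumulation with an informative exception in place of the -1 sentinel. `GradNormOverflowError`, `complete_grad_norm` and the `param_norms` mapping (parameter id to per-parameter norm) are hypothetical names; unlike the quoted method, the sketch omits the data-parallel all_reduce and the pipeline/model-parallel filtering. One related, verifiable detail: in PyTorch, `torch.tensor(-1)` infers dtype `torch.int64`, which would be consistent with the "Got Long" error if the integer sentinel is later wrapped in a tensor without an explicit dtype (an assumption; that call site is not part of the quoted code).

```python
import math
from typing import Mapping


class GradNormOverflowError(RuntimeError):
    """Hypothetical exception type; not part of DeepSpeed."""


def complete_grad_norm(param_norms: Mapping[int, float]) -> float:
    """Return the global L2 norm computed from per-parameter L2 norms.

    Hypothetical helper: mirrors the accumulation in the quoted method
    (sum of squared per-parameter norms, then a square root) but raises
    an informative error instead of returning the -1 sentinel.
    """
    total_sq = 0.0
    for norm in param_norms.values():
        value = float(norm)
        # Multiplication yields inf on overflow; float ** 2 would raise OverflowError.
        total_sq += value * value
    total_norm = math.sqrt(total_sq)
    if not math.isfinite(total_norm):
        bad_ids = sorted(pid for pid, n in param_norms.items() if not math.isfinite(n))
        raise GradNormOverflowError(
            f"Global gradient norm is {total_norm}; non-finite per-parameter "
            f"norms at param ids {bad_ids}. Check the loss scale, learning rate "
            f"and input data for numerical instability."
        )
    return total_norm
```

Raising here stops the run at the first bad step and names the offending parameters, instead of letting -1 flow into `torch.norm(torch.stack(norm_groups))`.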

Any progress? I encountered the same bug when using ZeRO-2 offload, while ZeRO-3 offload works correctly.

yiranma0, May 17 '24 20:05

I would also love an update on this; ZeRO-2 offloading doesn't work for me either, with the same error.

ShuaHousetable, Jun 21 '24 14:06

Same here with ZeRO-2: the -1 keeps getting passed to the next computation step, which expects a float tensor, and that causes the error.

disperaller, Jun 25 '24 06:06

"Same here with ZeRO-2: the -1 keeps getting passed to the next computation step, which expects a float tensor, and that causes the error."

Setting overlap_comm to False resolved the issue for me.

disperaller, Jun 25 '24 07:06
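The overlap_comm workaround reported above is a one-key change to the DeepSpeed JSON config. A minimal sketch, assuming the config has already been loaded into a dict; `disable_overlap_comm` is a hypothetical helper name, and it returns a modified copy rather than editing in place.

```python
import copy


def disable_overlap_comm(ds_config: dict) -> dict:
    """Return a copy of a DeepSpeed config dict with zero_optimization.overlap_comm = False.

    Hypothetical helper; the input dict is left unchanged.
    """
    new_config = copy.deepcopy(ds_config)
    new_config.setdefault("zero_optimization", {})["overlap_comm"] = False
    return new_config
```

To apply it to the file from the original report, load ./config/zero/zero_stage2_bf16.json with `json.load`, pass the result through the helper, and write it back with `json.dump`.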

I solved this issue by changing the model loading from

    model = AutoModelForCausalLM.from_pretrained(
        "/tmp/customer/duanzhichao/models/Qwen2-7B-Instruct",
        torch_dtype="auto",
    )

to

    model = AutoModelForCausalLM.from_pretrained(
        "/tmp/customer/duanzhichao/models/Qwen2-7B-Instruct"
    )

ZhichaoDuan, Jul 01 '24 07:07

@nickyoungforu are you still hitting this?

loadams, Jul 09 '24 21:07

Switching to precision=32 fixed this for me. I see you had "--precision bf16" in your launch command; maybe that was the problem. I had the same error when my precision was 16.

desire2020 is right about the cause of the error: I added logging in that function and saw param_norm go to NaN for all param_ids and on all GPUs at the same time when this happens.

AceMcAwesome77, Aug 21 '24 16:08
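The logging AceMcAwesome77 describes can be approximated with a small helper that reports which per-parameter norms are non-finite. This is a pure-Python sketch: `find_nonfinite_norms` is a hypothetical name, and it assumes a mapping from parameter id to a plain float. In the quoted method the values in `norm_for_param_grads` are tensors, so they would need `.item()` first.

```python
import logging
import math
from typing import Dict, List

logger = logging.getLogger("grad_norm_debug")


def find_nonfinite_norms(norms: Dict[int, float], rank: int = 0) -> List[int]:
    """Return the sorted parameter ids whose gradient norm is NaN or +/-inf, logging each."""
    bad_ids = sorted(pid for pid, value in norms.items() if not math.isfinite(value))
    for pid in bad_ids:
        logger.warning("rank %d: param_id %d has non-finite grad norm %r", rank, pid, norms[pid])
    return bad_ids
```

If every id is reported at the same step on every rank, as described above, the cause is shared across workers rather than tied to one rank's mini-batch.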

I hit the same problem, and because of CUDA memory limits I cannot switch the precision to float32.

Kamichanw, Sep 12 '24 02:09

"I solved this issue by changing the model loading from

    model = AutoModelForCausalLM.from_pretrained(
        "/tmp/customer/duanzhichao/models/Qwen2-7B-Instruct",
        torch_dtype="auto",
    )

to

    model = AutoModelForCausalLM.from_pretrained(
        "/tmp/customer/duanzhichao/models/Qwen2-7B-Instruct"
    )"

@tjruwase I hit the same problem when using torch_dtype=torch.bfloat16, and none of the methods mentioned above work for me. For all tensors in all 8 threads, param_norm suddenly changes to NaN at training step 10.

QingtaoLi1, Dec 13 '24 09:12

@QingtaoLi1, are you able to provide full repro steps?

tjruwase, Dec 23 '24 23:12

@tjruwase I found a solution in this unmerged PR. I also changed torch.bfloat16 to torch.float32, although I don't think that is necessary to fix this issue. But I still wonder why the grad norms can overflow, since bf16 and fp32 have the same ~3e38 range.

One more hint that this is a bug: I'm using 4 GPUs for training, each assigned a mini-batch different from the other three, but every time the overflow happens, all 4 workers overflow. I don't think this is a normal situation.

QingtaoLi1, Jan 07 '25 08:01
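QingtaoLi1's range question can be checked with a little arithmetic. bf16 and fp32 share an 8-bit exponent, so their largest finite values are close (about 3.39e38 vs 3.40e38); bf16 differs by having 7 explicit mantissa bits instead of 23, i.e. much lower precision rather than a smaller range. Two hedged observations about the quoted method: its check flags NaN as well as ±inf, so an "overflow" may really be a NaN, which is what the logging reported earlier in the thread showed; and it accumulates squared per-parameter norms and places that sum in a float32 tensor (`get_accelerator().FloatTensor`), so a total norm above roughly sqrt(FLT_MAX) ≈ 1.8e19 would already overflow there. The helper names below are hypothetical.

```python
import math


def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style binary format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the largest
    # unbiased exponent of a finite number equals the bias.
    max_exponent = bias
    return (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent


FP32_MAX = max_finite(exponent_bits=8, mantissa_bits=23)
BF16_MAX = max_finite(exponent_bits=8, mantissa_bits=7)
FP16_MAX = max_finite(exponent_bits=5, mantissa_bits=10)

# Largest total L2 norm whose square still fits in float32.
SQUARED_NORM_LIMIT = math.sqrt(FP32_MAX)


def squared_norm_overflows_fp32(total_norm: float) -> bool:
    """True if total_norm squared would exceed the float32 range."""
    return total_norm * total_norm > FP32_MAX


print(f"fp32 max {FP32_MAX:.4e}, bf16 max {BF16_MAX:.4e}, fp16 max {FP16_MAX:.0f}")
print(f"norm above which the squared sum overflows fp32: {SQUARED_NORM_LIMIT:.4e}")
```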

Three "solutions" work in my case:

  1. use ZeRO-2 + bf16 instead of ZeRO-2 offload + bf16;
  2. use fp16 instead of bf16 (works with ZeRO-2 offload);
  3. change the source code as in the PR that @QingtaoLi1 mentioned.

Smu-Tan, Jan 22 '25 20:01
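Workarounds 1 and 2 from the list above amount to small edits of the DeepSpeed JSON config. The sketch below derives both variants from a config dict like the one in the original report. The helper names are hypothetical, and the fp16 block is reduced to {"enabled": true}, leaving DeepSpeed's loss-scaling options at their defaults.

```python
import copy


def without_offload(ds_config: dict) -> dict:
    """Workaround 1 (sketch): keep ZeRO-2 + bf16 but drop CPU offloading."""
    cfg = copy.deepcopy(ds_config)
    zero = cfg.get("zero_optimization", {})
    zero.pop("offload_optimizer", None)
    # DeepSpeed documents offload_param as valid only with ZeRO stage 3.
    zero.pop("offload_param", None)
    return cfg


def fp16_instead_of_bf16(ds_config: dict) -> dict:
    """Workaround 2 (sketch): keep ZeRO-2 offload but switch bf16 to fp16."""
    cfg = copy.deepcopy(ds_config)
    cfg.pop("bf16", None)
    cfg["fp16"] = {"enabled": True}
    return cfg
```

The training script's own precision flag (--precision bf16 in the launch command above) would presumably need to change to match; that is an assumption about train.py, which is not shown.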

One quick fix that worked in my case: setting overlap_comm to false.

SwayamInSync, Jan 25 '25 14:01

@SwayamInSync, @Smu-Tan, @QingtaoLi1, @Kamichanw, @AceMcAwesome77, @desire2020 please try #6976

tjruwase, Jan 30 '25 15:01

"One quick fix that worked in my case: setting overlap_comm to false."

@SwayamInSync, can you please share your repro to help us debug why overlap_comm is triggering this issue? Thanks!

tjruwase, Jan 30 '25 15:01

"...other three, but every time the overflow happens, all 4 workers overflow. I don't think this is a normal situation."

@QingtaoLi1, this is probably because overflow checking is based on the reduced gradients, so the bad gradients from one worker would already have propagated to the other workers.

tjruwase, Feb 08 '25 08:02
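tjruwase's explanation matches the quoted complete_grad_norm_calculation_for_cpu_offload: each rank's sum of squared norms is all-reduced with SUM over the data-parallel group before the inf/NaN check, so a single non-finite contribution gives every rank the same non-finite total. The simulation below imitates that with plain Python lists instead of torch.distributed; `simulated_all_reduce_sum` and `check_total_norm` are hypothetical names.

```python
import math
from typing import List


def simulated_all_reduce_sum(per_rank_values: List[float]) -> List[float]:
    """Every rank receives the sum of all ranks' values, as with all_reduce(op=SUM)."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]


def check_total_norm(reduced_sq_norm: float) -> float:
    """Mirror the quoted check: square root of the reduced sum, or -1 if it is inf/NaN."""
    total_norm = math.sqrt(reduced_sq_norm)
    if not math.isfinite(total_norm):
        return -1
    return total_norm


# Four data-parallel ranks; only rank 2 produced a NaN squared norm.
local_sq_norms = [1.0, 4.0, float("nan"), 9.0]
results = [check_total_norm(v) for v in simulated_all_reduce_sum(local_sq_norms)]
print(results)  # [-1, -1, -1, -1]
```

Because the NaN enters the shared sum, all four ranks report the overflow at the same step even though only one rank's mini-batch produced it.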