DeepSpeed

[BUG] deepspeed overlap_comm data race

Open yangyihang-bytedance opened this issue 1 year ago • 2 comments

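Describe the bug: As illustrated below, DeepSpeed's overlap buffer design has a potential data race. With `overlap_comm` and `contiguous_gradients` enabled, gradients are copied into one of two alternating IPG buffers on the default stream while the other buffer is reduced on a separate reduction stream, but nothing makes the default stream wait for the earlier reduction of a buffer before that buffer is overwritten again. I have written a patch to fix the bug.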

Could you kindly help diagnose and fix this issue?

[Image: whiteboard diagram illustrating the race (not included in this copy)]

To Reproduce

debug.py

import argparse
from this import d
import deepspeed.runtime.zero.stage_1_and_2
import torch
import torch.nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import set_seed
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import is_moe_param

def patch_deepspeed():
    """Monkey-patch DeepSpeed's ZeRO stage 1/2 optimizer to guard IPG buffer reuse.

    With ``overlap_comm`` and ``contiguous_gradients`` enabled, gradients are copied
    into one of two alternating ``ipg_buffer`` tensors on the current stream while a
    previously filled buffer is reduced on ``self.reduction_stream``. The upstream
    code appears to swap buffers without making the current stream wait for the
    earlier reduction of the buffer it is about to overwrite (see the sanitizer
    trace in this issue). This patch records one event per buffer on the reduction
    stream and waits on it before that buffer is written again.

    Replaces ``DeepSpeedZeroOptimizer.backward`` and
    ``DeepSpeedZeroOptimizer.reduce_independent_p_g_buckets_and_remove_grads``.
    """

    def backward(self, loss, retain_graph=False):
        """Allocate fresh IPG buffers (with event slots) and backpropagate the scaled loss.

        Steps:

        1. If ``contiguous_gradients`` is enabled, allocate one IPG buffer of
           ``reduce_bucket_size`` elements (two when ``overlap_comm`` is enabled),
           each paired with an empty reduction-completion event slot.
        2. Backpropagate: ``(external_loss_scale * loss).backward()`` when a custom
           loss scaler is configured, otherwise ``loss_scaler.backward(loss.float())``.
        3. Stage 1, mode 2 only: fill the gradient-accumulation attribute.

        Args:
            loss: Loss tensor to backpropagate.
            retain_graph: Forwarded to the autograd backward call on both paths.
        """
        self.micro_step_id += 1

        if self.contiguous_gradients:
            # A second buffer lets the next bucket be filled while the previous one
            # is still being reduced when overlap_comm is enabled.
            num_buffers = 2 if self.overlap_comm else 1
            self.ipg_buffer = [
                torch.empty(int(self.reduce_bucket_size),
                            dtype=self.dtype,
                            device=get_accelerator().current_device_name())
                for _ in range(num_buffers)
            ]
            # ipg_events[k]: event recorded on the reduction stream after the last
            # reduction of ipg_buffer[k] was issued; None means "never reduced".
            self.ipg_events = [None] * num_buffers
            self.ipg_index = 0

        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            # Forward retain_graph here too; the original dropped it on this path.
            scaled_loss.backward(retain_graph=retain_graph)
        else:
            self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

        # Only for Stage 1, Mode 2
        if self.use_grad_accum_attribute:
            self.fill_grad_accum_attribute()

    def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
        """Append ``param``'s gradient to the current IPG bucket, reducing the bucket first if full.

        When the bucket would overflow, the bucket is reduced. With contiguous gradients
        and overlap_comm, the buffer index is then swapped, and the current stream waits
        until the newly selected buffer's previous reduction has finished before the
        gradient is copied into it.

        Args:
            param: Parameter whose gradient is ready.
            i: Index passed through from the caller and stored in
                ``params_in_ipg_bucket`` (presumably the param-group index; confirm).
        """
        grad_reduc = self.get_gradient_for_reduction(param)
        if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
            self.reduce_ipg_grads()

            # ipg_events only exists (and the double buffer is only used) in this mode;
            # touching it otherwise would fail.
            if self.contiguous_gradients and self.overlap_comm:
                accelerator = get_accelerator()
                if not accelerator.is_synchronized_device():
                    # Mark the point on the reduction stream after which the buffer just
                    # handed to reduce_ipg_grads() may safely be overwritten again.
                    reduced_event = accelerator.Event()
                    reduced_event.record(self.reduction_stream)
                    self.ipg_events[self.ipg_index] = reduced_event

                # Swap ipg_index between 0 and 1
                self.ipg_index = 1 - self.ipg_index

                # Make the current stream wait for the previous reduction (if any) of
                # the buffer we are about to write into.
                pending_event = self.ipg_events[self.ipg_index]
                if pending_event is not None:
                    pending_event.wait()

            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())

        param_id = self.get_param_id(param)
        assert self.params_already_reduced[param_id] == False, \
            f"The parameter {param_id} has already been reduced. \
            Gradient computed twice for this partition. \
            Multiple gradient reduction is currently not supported"

        if self.contiguous_gradients:
            if param.numel() > self.reduce_bucket_size:
                self.extra_large_param_to_reduce = param
            else:
                # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
                new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
                new_grad_tensor.copy_(grad_reduc.view(-1))
                grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)

        self.elements_in_ipg_bucket += param.numel()

        assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

        self.grads_in_ipg_bucket.append(grad_reduc)
        self.params_in_ipg_bucket.append((i, param, param_id))

        #make sure the average tensor function knows how to average the gradients
        if is_moe_param(param):
            self.ipg_bucket_has_moe_params = True

        self.report_ipg_memory_usage("End ipg_remove_grads", 0)

    deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.backward = backward
    deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.reduce_independent_p_g_buckets_and_remove_grads = reduce_independent_p_g_buckets_and_remove_grads

class DummyModule(torch.nn.Module):
    """Toy model: a stack of embedding tables whose full weight matrices are added to the input.

    Args:
        n_layer: Number of embedding tables.
        hidden_size: Embedding dimension of each table.
        vocab_size: Number of rows in each table.
    """

    def __init__(self, n_layer=16, hidden_size=1024, vocab_size=1024) -> None:
        super().__init__()

        self._vocab_size = vocab_size

        self.embs = torch.nn.ModuleList(
            [torch.nn.Embedding(vocab_size, hidden_size) for _ in range(n_layer)]
        )

    def init_weights(self):
        """Re-initialize every table with N(0, 0.0002) and zero its padding row, if any.

        Not called anywhere in this script.
        """
        for emb in self.embs:
            emb.weight.data.normal_(mean=0.0, std=0.0002)
            if emb.padding_idx is not None:
                emb.weight.data[emb.padding_idx].zero_()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``input`` plus the sum of every table's full weight matrix.

        Looking up all indices ``0 .. vocab_size - 1`` yields each table's whole weight,
        shape ``(vocab_size, hidden_size)``; ``input`` must broadcast to that shape
        (``train()`` passes a shape-``(1,)`` tensor).
        """
        # torch.arange(n) yields 0..n-1, the same indices as the deprecated
        # torch.range(0, n - 1), without the deprecation warning.
        idx = torch.arange(self._vocab_size, device=input.device, dtype=torch.long)

        hidden_states = input
        for emb in self.embs:
            hidden_states = hidden_states + emb(idx)

        return hidden_states.mean()

def train():
    """Run a 20-step toy training loop under Accelerate, logging loss and grad norm on rank 0.

    Inputs and targets are single random fp16 values.
    """
    torch.use_deterministic_algorithms(True)

    accelerator = Accelerator(project_dir="./outputs")
    device, dtype = accelerator.device, torch.float16

    # The seed depends on the rank, so each process draws different random inputs.
    set_seed(42 + accelerator.process_index)

    net = DummyModule()
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    net, opt = accelerator.prepare(net, opt)

    for step in range(20):
        inputs = torch.rand(1, dtype=dtype, device=device)
        target = torch.rand(1, dtype=dtype, device=device)

        opt.zero_grad()
        prediction = net(inputs)

        loss = F.mse_loss(prediction.float(), target.float())

        accelerator.backward(loss)
        opt.step()

        # -100.0 is the placeholder used when the wrapped optimizer exposes no
        # `_global_grad_norm` attribute.
        grad_norm = getattr(opt.optimizer, '_global_grad_norm', -100.0)

        if accelerator.process_index != 0:
            continue
        print(f'rank [{accelerator.process_index}] global_step [{step}] loss [{loss.item():.9}] grad_norm [{grad_norm:.9}]')

def main():
    """Parse the command line, apply the DeepSpeed patch if requested, then train."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--patch_deepspeed", action='store_true', help="patch deepspeed")
    should_patch = cli.parse_args().patch_deepspeed

    if should_patch:
        patch_deepspeed()
    train()


if __name__ == '__main__':
    main()

Unexpected behavior

System info (please complete the following information):

  • GPU count and types: one machine with 2x A100 GPUs

Launcher context

dp_zero1_fp16.yaml

{
    "train_micro_batch_size_per_gpu": 6,
    "steps_per_print": 100,
    "prescale_gradients": false,
    "zero_allow_untested_optimizer": true,
    "gradient_accumulation_steps": "auto",
    "bf16": {
        "enabled": false
    },
    "fp16": {
        "enabled": true
    },
    "wall_clock_breakdown": false,
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 1,
        "allgather_partitions": true,
        "reduce_scatter": true,
        "allgather_bucket_size": 1e8,
        "reduce_bucket_size": 1048576,
        "stage3_max_reuse_distance": 2e9,
        "overlap_comm": true,
        "contiguous_gradients": true
    }
}

debug.sh

#!/bin/bash
# Launch debug.py as 2 processes on a single machine via `accelerate launch`, using
# the DeepSpeed config in dp_zero1_fp16.yaml, with the PyTorch CUDA sanitizer enabled.
# Append --patch_deepspeed after debug.py to run with the patched optimizer.
# main_host / main_port may be set in the environment; otherwise local defaults apply.

set -ex

num_processes=2
# Default these so unset variables don't produce empty, misparsed launcher arguments.
main_host="${main_host:-127.0.0.1}"
main_port="${main_port:-29500}"

# Enables torch.cuda._sanitizer, which reports unsynchronized cross-stream tensor access.
export TORCH_CUDA_SANITIZER=1

accelerate launch --main_process_ip "$main_host" --main_process_port "$main_port" \
    --num_machines 1 --machine_rank 0 --num_processes "$num_processes" \
    --use_deepspeed --deepspeed_config_file dp_zero1_fp16.yaml --deepspeed_multinode_launcher standard \
    debug.py

Additional context: A similar issue is present in ZeRO stage 3 as well, but I have not yet prepared a patch for it.

yangyihang-bytedance avatar May 18 '24 06:05 yangyihang-bytedance

Additional context

deepspeed==0.14.2
CSAN detected a possible data race on tensor with data pointer 140108946735104
Access by stream 0 during kernel:
aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
writing to argument(s) self, and to the output
With stack trace:
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 949, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 570, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 371, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(

Previous access by stream 152815408 during kernel:
aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)
writing to argument(s) self, and to the output
With stack trace:
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 903, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1416, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 932, in reduce_independent_p_g_buckets_and_remove_grads
    self.reduce_ipg_grads()
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1367, in reduce_ipg_grads
    self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1127, in average_tensor
    self.allreduce_and_scatter(buckets[bucket_key],
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1031, in allreduce_and_scatter
    self.allreduce_and_copy_with_multiple_ranks(small_bucket,
  File "/usr/local/lib/python3.9/dist-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1005, in allreduce_and_copy_with_multiple_ranks
    for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks):
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 534, in _unflatten_dense_tensors
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 570, in __torch_dispatch__
    errors = self.event_handler._handle_kernel_launch(
  File "/usr/local/lib/python3.9/dist-packages/torch/cuda/_sanitizer.py", line 371, in _handle_kernel_launch
    stack_trace = traceback.StackSummary.extract(

yangyihang-bytedance avatar May 18 '24 06:05 yangyihang-bytedance

@yangyihang-bytedance, can you please confirm if this was fixed by #5606? Thanks!

tjruwase avatar Aug 03 '24 17:08 tjruwase

This issue is not fixed yet.

wizyoung avatar Dec 11 '24 02:12 wizyoung