
apply reduce_scatter_coalesced op

inkcherry opened this pull request 1 year ago • 2 comments

Reduce the time and memory overhead of maintaining flattened buffers.
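
For context, here is a minimal sketch of the flatten-then-reduce-scatter baseline pattern whose overhead this PR targets. The function name and helper logic below are illustrative assumptions, not code from this PR or from DeepSpeed:

    # A minimal sketch of the baseline: flatten grads into one contiguous
    # buffer, pad, reduce-scatter, and return this rank's partition.
    import torch
    import torch.distributed as dist

    def reduce_scatter_via_flat_buffer(grads, group=None):
        world_size = dist.get_world_size(group)
        # Flattening allocates a fresh contiguous buffer and copies every grad.
        flat = torch.cat([g.reshape(-1) for g in grads])
        # Pad so the buffer divides evenly across ranks (another allocation).
        pad = (world_size - flat.numel() % world_size) % world_size
        if pad:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        out = flat.new_empty(flat.numel() // world_size)
        dist.reduce_scatter_tensor(out, flat, group=group)  # defaults to SUM
        return out  # this rank's partition of the summed gradients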

inkcherry avatar Mar 04 '24 12:03 inkcherry

@inkcherry, thanks for this PR. Are you able to provide some observed memory and latency benefits?

tjruwase avatar Apr 05 '24 02:04 tjruwase

> @inkcherry, thanks for this PR. Are you able to provide some observed memory and latency benefits?

Hi @tjruwase, I used a setup of 4x A800 80GB GPUs with PyTorch 2.2. I ran the script available at https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/ds_train_finetune.sh and set zero_stage to 3.
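
For reference, a minimal DeepSpeed configuration enabling ZeRO stage 3 looks roughly like this (an illustrative excerpt; the linked script sets additional fields, and the batch-size and precision values below are assumptions, not taken from ds_train_finetune.sh):

    # Illustrative ZeRO stage 3 config excerpt; only "stage" is the setting
    # discussed here, the other values are assumed placeholders.
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {"stage": 3},
        "bf16": {"enabled": True},
    }
    # Typically passed to DeepSpeed at startup, e.g.:
    # engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)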

Average throughput improved from 1.71 it/s to 1.77 it/s (about a 3.5% end-to-end gain; the modified function itself is roughly 20% faster). With seq_length set to 1k, the average step time improved from 1.42 s/it to 1.39 s/it (about 2%).

As in that latter case, for some other models the end-to-end improvement may not be significant, because the share of total time spent here (preparing flattened data for reduce_scatter) is relatively low. Similarly, it seems that __avg_scatter_contiguous_grads also avoids this data-preparation work by using allreduce.
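
To illustrate that point, here is a hedged sketch (not the actual __avg_scatter_contiguous_grads implementation) of why an allreduce on an already-contiguous gradient buffer needs no preparation step: each rank reduces in place and then keeps only its own slice, mimicking a reduce-scatter result:

    import torch
    import torch.distributed as dist

    def avg_scatter_contiguous(flat_grads: torch.Tensor, group=None):
        # flat_grads is assumed to already be one contiguous, evenly
        # divisible buffer, so no flattening or padding copies are needed.
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, group=group)
        flat_grads.div_(world_size)  # average in place
        partition = flat_grads.numel() // world_size
        # Each rank keeps only its own slice of the averaged gradients.
        return flat_grads.narrow(0, rank * partition, partition)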

Regarding memory, I mean that it may avoid some allocations: https://github.com/microsoft/DeepSpeed/blob/4829bdd78ac333273b4aa4ef13d5881af1b8ac51/deepspeed/runtime/comm/coalesced_collectives.py#L129

inkcherry avatar Apr 08 '24 15:04 inkcherry

Hi @inkcherry ,

Thanks for the PR. Based on my understanding, this PR's reduce-scatter logic is identical to what we already have. I also tested on an 8x A100 DGX node, but unfortunately I did not see any end-to-end speedup.

Regarding memory savings, your PR also performs a torch.cat, as follows:

                flattened_tensor = torch.cat([
                    flattened_tensor,
                    torch.empty(padding_size_list[tensor_idx],
                                dtype=flattened_tensor.dtype,
                                device=flattened_tensor.device)
                ])

Therefore, I don't see a memory saving either.
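
A quick standalone snippet (independent of the PR) shows why: torch.cat always materializes a new tensor, so padding this way allocates a buffer of the combined size rather than reusing the original storage.

    import torch

    t = torch.randn(1000)
    padded = torch.cat([t, torch.empty(24)])
    # cat copies into freshly allocated storage; the original buffer
    # is not reused, so the allocation is not avoided.
    assert padded.data_ptr() != t.data_ptr()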

Let me know if there is any misunderstanding on my side.

GuanhuaWang avatar Jul 15 '24 20:07 GuanhuaWang