apply reduce_scatter_coalesced op
Saves the time and memory overhead of maintaining flattened buffers.
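For context, here is a minimal sketch (not DeepSpeed's actual implementation) of the flatten-pad-reduce_scatter pattern whose upkeep this PR targets; the helper name and the `grads`/`group` arguments are illustrative assumptions:

```python
import torch
import torch.distributed as dist

def flatten_pad_reduce_scatter(grads, group=None):
    """Illustrative only: flatten grads into one buffer, pad it, then reduce-scatter."""
    world_size = dist.get_world_size(group=group)
    flat = torch.cat([g.contiguous().view(-1) for g in grads])
    # Pad so the flat buffer splits evenly across ranks.
    remainder = flat.numel() % world_size
    if remainder:
        pad = torch.empty(world_size - remainder, dtype=flat.dtype, device=flat.device)
        flat = torch.cat([flat, pad])  # extra allocation + copy just for padding
    out = torch.empty(flat.numel() // world_size, dtype=flat.dtype, device=flat.device)
    dist.reduce_scatter_tensor(out, flat, op=dist.ReduceOp.SUM, group=group)
    return out
```

Building and padding `flat` on every step is the bookkeeping cost that a coalesced reduce-scatter op can avoid.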
@inkcherry, thanks for this PR. Are you able to provide some observed memory and latency benefits?
Hi @tjruwase, I used a setup of 4x A800 80GB GPUs with PyTorch 2.2. I ran the script at https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/ds_train_finetune.sh with zero_stage set to 3.
Average throughput improved from 1.71 it/s to 1.77 it/s (the modified function itself is roughly 20% faster). With seq_length set to 1k, the average step time improved from 1.42 s/it to 1.39 s/it.
As in the latter case, the end-to-end improvement may not be significant for some other models, because the work affected here (preparing the flat data for reduce_scatter) accounts for a relatively small share of the step time. Similarly, it seems that `__avg_scatter_contiguous_grads` also avoids this preparation work by going through allreduce.
Regarding memory, I mean that it may avoid some allocations: https://github.com/microsoft/DeepSpeed/blob/4829bdd78ac333273b4aa4ef13d5881af1b8ac51/deepspeed/runtime/comm/coalesced_collectives.py#L129
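To gauge how large that preparation share actually is, one could time the flatten/pad step alone with CUDA events. This is a generic measurement sketch under assumed inputs (`grads`, `world_size`), not the benchmark used for the numbers above:

```python
import torch

def time_flatten_prep(grads, world_size, iters=100):
    """Rough per-iteration cost (ms) of the flatten+pad preparation alone."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        flat = torch.cat([g.contiguous().view(-1) for g in grads])
        remainder = flat.numel() % world_size
        if remainder:
            flat = torch.cat([flat, flat.new_empty(world_size - remainder)])
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```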
Hi @inkcherry ,
Thanks for the PR. Based on my understanding, this PR's reduce-scatter logic is identical to what we already have. I also tested on an 8x A100 DGX node, but unfortunately I did not see any e2e speedup.
Regarding the memory savings, your PR also does a `torch.cat`, as follows:
```python
flattened_tensor = torch.cat([
    flattened_tensor,
    torch.empty(padding_size_list[tensor_idx],
                dtype=flattened_tensor.dtype,
                device=flattened_tensor.device)
])
```
Therefore, I don't see the memory savings either.
Let me know if there is any misunderstanding on my side.
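For reference, here is a small standalone check of why that `torch.cat` padding still costs an allocation: `cat` always materializes a new output buffer and copies both inputs into it, rather than growing `flattened_tensor` in place. The tensor sizes below are arbitrary.

```python
import torch

flat = torch.arange(10, dtype=torch.float32)
pad = torch.empty(2, dtype=flat.dtype)
padded = torch.cat([flat, pad])

# torch.cat returns a freshly allocated tensor; the original storage is not reused.
assert padded.data_ptr() != flat.data_ptr()
assert padded.numel() == flat.numel() + pad.numel()
```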