[BUG] When using pure DeepSpeed Ulysses and ZeRO stage 3 to continue pre-training, the loss gap between GPUs is too large
Describe the bug
When I use pure DeepSpeed Ulysses and ZeRO stage 3 to continue pre-training, the loss gap between GPUs is too large: the loss on GPU 0 is around 1.5-2, while the loss on GPUs 1-3 is around 6-7.
GPU: 4x A100 (80 GB)
Model: Qwen-7B-base
To Reproduce
I use pure DeepSpeed for training; the training code is based on pre_train-Baichuan. In order to use DeepSpeed sequence parallelism, I made the following modifications:
1. Modify attention. Here the attention uses FlashAttention-2's flash_attn_func in place of the original implementation:

```python
def forward(self, ...):
    ...
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    ...
```

I change this into:

```python
def __init__(self, ...):
    ...
    self.dist_ulysser_attn = UlyssesAttention()
    ...

def forward(self, ...):
    ...
    attn_output = self.dist_ulysser_attn(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    ...
```

Here, UlyssesAttention() is implemented with the DeepSpeed-Ulysses technique mentioned above; a sketch of how such a wrapper can be built follows below.
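For reference, a minimal sketch of how such a wrapper could be built on top of DeepSpeed's DistributedAttention layer (deepspeed.sequence.layer). The function names and the use of flash_attn_func as the local attention are illustrative assumptions, not the exact code used in this issue:

```python
# Sketch only: wrap flash_attn_func in DeepSpeed's DistributedAttention so that
# each rank attends over the full sequence for a subset of heads (via all-to-all)
# and then gets its own sequence chunk back on the output side.
from deepspeed.sequence.layer import DistributedAttention
from flash_attn import flash_attn_func

def local_flash_attention(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True):
    # Local attention over the gathered full-length sequence / head subset.
    return flash_attn_func(q, k, v, dropout_p, softmax_scale=softmax_scale, causal=causal)

def build_ulysses_attention(sequence_parallel_group):
    # DistributedAttention performs the two all-to-all exchanges
    # (sequence <-> head re-partitioning) around the local attention call.
    return DistributedAttention(local_flash_attention, sequence_parallel_group)
```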
2. Define the sequence-parallel group:

```python
import torch
import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None

def initialize_model_parallel(
    sequence_parallel_size,
):
    world_size = dist.get_world_size()
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size

    global _SEQUENCE_PARALLEL_GROUP
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        group = torch.distributed.new_group(ranks)
        rank = dist.get_rank()
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP
```
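The data-splitting code in step 3 also calls get_sequence_parallel_rank and get_sequence_parallel_world_size, which are not shown in the issue; here is a sketch of how they could be defined on top of the group created above (an assumption, not necessarily the original code):

```python
# Hypothetical helpers built on the _SEQUENCE_PARALLEL_GROUP defined above.
import torch.distributed as dist

def get_sequence_parallel_world_size():
    # Number of ranks that jointly hold one full sequence.
    return dist.get_world_size(group=get_sequence_parallel_group())

def get_sequence_parallel_rank():
    # This rank's position within its sequence-parallel group.
    return dist.get_rank(group=get_sequence_parallel_group())
```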
3. Modify the training input data: split each sample along the sequence length (a small sanity check of the slicing follows after this code).

```python
...
# get the sequence-parallel sub_seq_start and sub_seq_end for this rank
seq_parallel_rank = get_sequence_parallel_rank()
seq_parallel_world_size = get_sequence_parallel_world_size()
seq_length = self.max_length // seq_parallel_world_size
self.sub_seq_start = seq_parallel_rank * seq_length
self.sub_seq_end = (seq_parallel_rank + 1) * seq_length
...
```

```python
while True:
    model_engine.train()
    step = 0
    while step < args.steps_per_epoch:
        # split the data along the sequence dimension for this rank
        data = next(data_iter)[:, data_engine.sub_seq_start:data_engine.sub_seq_end]
        loss = model_engine(data, labels=data).loss
        model_engine.backward(loss)
        model_engine.step()
        step += 1
    epoch += 1
    new_steps = args.laststep + epoch * args.steps_per_epoch
    model_engine.save_checkpoint(f"{args.checkpoint_saving_path}",
                                 tag=f"checkpoint-{new_steps}")
```
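As a quick sanity check of the slicing above, with illustrative values (max_length = 4096 and a sequence-parallel world size of 4, not taken from the issue), each rank receives one contiguous N/P chunk of every sample:

```python
# Illustrative values only.
max_length, sp_world_size = 4096, 4
seq_length = max_length // sp_world_size  # 1024 tokens per rank
for sp_rank in range(sp_world_size):
    sub_seq_start = sp_rank * seq_length
    sub_seq_end = (sp_rank + 1) * seq_length
    print(f"rank {sp_rank}: tokens [{sub_seq_start}, {sub_seq_end})")
# rank 0: tokens [0, 1024)
# rank 1: tokens [1024, 2048)
# rank 2: tokens [2048, 3072)
# rank 3: tokens [3072, 4096)
```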
4. Check the loss in the log:

```
2024-04-16 17:47:15,911 : ******loss****** step:1 loss:6.658958435058594 gpu:1
2024-04-16 17:47:15,912 : ******loss****** step:1 loss:6.659845352172852 gpu:3
2024-04-16 17:47:15,915 : ******loss****** step:1 loss:2.49433970451355 gpu:0
2024-04-16 17:47:15,914 : ******loss****** step:1 loss:6.686857223510742 gpu:2
2024-04-16 17:47:38,869 : ******loss****** step:2 loss:7.213046073913574 gpu:1
2024-04-16 17:47:38,870 : ******loss****** step:2 loss:7.709272861480713 gpu:3
2024-04-16 17:47:38,870 : ******loss****** step:2 loss:6.571146488189697 gpu:2
2024-04-16 17:47:38,871 : ******loss****** step:2 loss:2.100698947906494 gpu:0
2024-04-16 17:47:59,575 : ******loss****** step:3 loss:6.615096092224121 gpu:2
2024-04-16 17:47:59,575 : ******loss****** step:3 loss:5.659509181976318 gpu:1
2024-04-16 17:47:59,576 : ******loss****** step:3 loss:2.1568846702575684 gpu:0
2024-04-16 17:47:59,577 : ******loss****** step:3 loss:6.5622239112854 gpu:3
2024-04-16 17:48:20,795 : ******loss****** step:4 loss:4.411237716674805 gpu:1
2024-04-16 17:48:20,795 : ******loss****** step:4 loss:6.673219203948975 gpu:2
2024-04-16 17:48:20,796 : ******loss****** step:4 loss:6.635258197784424 gpu:3
2024-04-16 17:48:20,797 : ******loss****** step:4 loss:2.3108065128326416 gpu:0
2024-04-16 17:48:41,274 : ******loss****** step:5 loss:6.497841835021973 gpu:1
2024-04-16 17:48:41,275 : ******loss****** step:5 loss:6.800827503204346 gpu:2
2024-04-16 17:48:41,275 : ******loss****** step:5 loss:8.080429077148438 gpu:3
2024-04-16 17:48:41,276 : ******loss****** step:5 loss:2.609225273132324 gpu:0
```
We can clearly see that the loss on GPU 0 is much smaller than the loss on the other GPUs.
Expected behavior
The loss on each GPU should be around 2 (because this is continued pre-training).
ds_report output
Screenshots
No screenshots needed.
System info (please complete the following information):
- OS: Ubuntu 22.04
- GPU count and types: one machine with 4x A100
- Python version: 3.10.14
Launcher context
My DeepSpeed config:
```json
{
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "prescale_gradients": false,
  "zero_allow_untested_optimizer": true,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-8,
      "eps": 1.0e-8,
      "betas": [0.9, 0.95],
      "weight_decay": 0.1
    }
  },
  "tensorboard": {
    "enabled": true,
    "output_path": "logs/",
    "job_name": "qwen-7b-pt"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "steps_per_print": 16,
  "gradient_clipping": 1.0,
  "wall_clock_breakdown": true,
  "bf16": {
    "enabled": true
  }
}
```
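For completeness, a sketch of how a config like this is typically passed to deepspeed.initialize; the file name and the model variable are placeholders, and the optimizer is built by DeepSpeed from the config above:

```python
# Sketch: initialize the DeepSpeed engine with the JSON config shown above.
import deepspeed

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,                          # e.g. the Qwen-7B model instance
    model_parameters=model.parameters(),
    config="ds_config.json",              # path to the config above (placeholder name)
)
```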
I have three questions about pre-training with pure DeepSpeed and DeepSpeed-Ulysses:
- Is there something wrong with the processing in my pre-training code?
- I have always wondered whether I need to merge the output (N/P) back into the full output (N) after calculating attention and then compute the loss on it, but I feel that this is not appropriate.
- For pure DeepSpeed, when model.backward(loss) is finally executed, an all-reduce is performed on the loss of each rank. If the outputs are not merged, then the loss each rank computes over only its N/P slice of the sequence is passed directly to model.backward(loss). Is that logically correct?
Would you mind taking a look at this issue? Thanks! @samadejacobs, @RezaYazdaniAminabadi
@Kwen-Chen, your input data processing looks good to me. As for your second and third questions, you need a sequence-parallel-aware loss calculation (see example here).
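For illustration, a minimal sketch of what a sequence-parallel-aware loss could look like (a Megatron-style sum-and-all-reduce over the sequence-parallel group). This is an assumption-based example, not the code linked in the reply, and it assumes each rank already holds shifted logits and labels for its own N/P chunk:

```python
# Sketch only: combine per-chunk token losses across the sequence-parallel group
# so that every rank optimizes and reports the loss of the full sequence.
import torch
import torch.distributed as dist
import torch.nn.functional as F

class _AllReduceSum(torch.autograd.Function):
    """SUM all-reduce whose backward is the identity: the upstream gradient is
    already replicated on every rank, so it is passed straight through."""

    @staticmethod
    def forward(ctx, tensor, group):
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def sequence_parallel_loss(logits, labels, sp_group):
    # Token-level loss summed over this rank's chunk (ignored tokens excluded).
    local_sum = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1),
        ignore_index=-100, reduction="sum")
    local_count = (labels != -100).sum().to(local_sum.dtype)
    # Aggregate loss sums and token counts over the sequence-parallel group.
    global_sum = _AllReduceSum.apply(local_sum, sp_group)
    global_count = local_count.clone()
    dist.all_reduce(global_count, group=sp_group)
    return global_sum / global_count
```

With a loss aggregated like this, every rank passes the same full-sequence value to model_engine.backward, so the per-GPU losses in the log should agree.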
Thanks for your reply! But I want to know: when I process the input data like that and use a sequence-parallel-aware loss calculation (this), a problem arises that I can't do model.backward(loss), because the model output (N/P) does not correspond to the loss (N).
It doesn't seem that we need the sequence-parallel-aware loss function according to this issue though:
https://github.com/microsoft/DeepSpeed/issues/5248
It seems that this has been handled implicitly by DeepSpeed Ulysses, right?
@samadejacobs