Support deepspeed sequence parallel
What does this PR do?
Support sequence parallelism with DeepSpeed-Ulysses.
I have tested training on starcoder2-3b; the loss decreases normally.
Requires https://github.com/huggingface/accelerate/pull/2877
~~I have made massive modifications to the original implementation of DeepSpeed-Ulysses to support the batch size dimension in layers.py. It uses all_to_all_single instead of all_to_all, like https://github.com/InternLM/InternEvo/blob/a61d391df96c5f5c243cdea32a5044b70d6fe33e/internlm/core/parallel/comm/isp.py#L628, for better performance. I have left some comments to aid future understanding.~~ Using all_to_all_single turned out to be too complex to support other scatter and gather indices.
Currently, flash attention and SDPA are tested for Llama and Mistral. Flash attention is also tested for starcoder2; SDPA for starcoder2 is not supported.
It requires a special dataloader (implemented in the Trainer) and a special data collator (example below). In the data collator, the sequence should be divided into multiple sub-sequences. The following is an example of sub-sequence processing in the data collator.
```python
seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
seq_parallel_world_rank = mpu.get_sequence_parallel_rank()
seq_length = input_ids.size(1)
sub_seq_length = seq_length // seq_parallel_world_size
sub_seq_start = seq_parallel_world_rank * sub_seq_length
sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length
# There is no kv cache when training
past_key_values_length = 0
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
batch = dict(
    input_ids=input_ids[:, sub_seq_start:sub_seq_end],
    labels=labels[:, sub_seq_start:sub_seq_end],
    position_ids=position_ids[:, sub_seq_start:sub_seq_end],
    attention_mask=(input_ids != self.tokenizer.pad_token_id),
)
```
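For concreteness, here is a tiny worked example (illustrative only, not part of the PR) of the slicing above: with a sequence length of 8 and a sequence-parallel world size of 4, each rank keeps one contiguous 2-token slice.

```python
# Illustrative only: which tokens each sequence-parallel rank keeps
# for seq_length = 8 and seq_parallel_world_size = 4.
import torch

input_ids = torch.arange(8).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
seq_parallel_world_size = 4
sub_seq_length = input_ids.size(1) // seq_parallel_world_size
for rank in range(seq_parallel_world_size):
    start, end = rank * sub_seq_length, (rank + 1) * sub_seq_length
    print(rank, input_ids[:, start:end].tolist())  # rank 0 -> [[0, 1]], rank 1 -> [[2, 3]], ...
```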
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@muellerzr and @SunMarc
Great, can you provide an example of data processing based on sequence parallelism? Thanks.
The dataset and sampler are handled in the Trainer:
https://github.com/huggingface/transformers/pull/31525/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR847-R855
The data collator example was accidentally deleted while editing; here it is again:
```python
seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
seq_parallel_world_rank = mpu.get_sequence_parallel_rank()
seq_length = input_ids.size(1)
sub_seq_length = seq_length // seq_parallel_world_size
sub_seq_start = seq_parallel_world_rank * sub_seq_length
sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length
# There is no kv cache when training
past_key_values_length = 0
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
batch = dict(
    input_ids=input_ids[:, sub_seq_start:sub_seq_end],
    labels=labels[:, sub_seq_start:sub_seq_end],
    position_ids=position_ids[:, sub_seq_start:sub_seq_end],
    attention_mask=(input_ids != self.tokenizer.pad_token_id),
)
```
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
How long until this PR is merged? When will it be finished?
cc @SunMarc if you have the bandwidth to take a look!
@zeyugao I carefully read your pull requests for transformers and accelerate, and pulled your code to try training. Now I have encountered a problem: when entering DistributedAttention, the q, k, v before _SeqAllToAll.apply are not [b, s/p, n, h], but still [b, s, n, h]. I checked the modified parts of the data processing, such as accelerate/data_loader.py and transformers/trainer.py, but did not find any relevant processing code. So, may I ask where the sequence splitting is done?
@glowwormX It is in the PR description.
@zeyugao My goodness, I missed it; I thought the code was in the PR itself. Thank you for replying.
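For reference, here is a minimal shape-only sketch (a hypothetical helper, not the code in this PR) of the Ulysses all-to-all that DistributedAttention performs: after the collator's split each rank holds q/k/v of shape [b, s/p, n, h], and the all-to-all trades sub-sequences for attention heads so that attention runs over the full sequence with n/p heads per rank.

```python
# A shape-only sketch, assuming the number of heads n is divisible by the
# sequence-parallel world size p; ulysses_all_to_all is a hypothetical name.
import torch
import torch.distributed as dist

def ulysses_all_to_all(x: torch.Tensor, group=None) -> torch.Tensor:
    # x: [b, s/p, n, h] on each rank  ->  returns [b, s, n/p, h]
    p = dist.get_world_size(group=group)
    # one head chunk per destination rank: p tensors of shape [b, s/p, n/p, h]
    send = [t.contiguous() for t in x.chunk(p, dim=2)]
    recv = [torch.empty_like(send[0]) for _ in range(p)]
    dist.all_to_all(recv, send, group=group)
    # rank i contributed sub-sequence i, so concatenating along the sequence
    # dimension restores the full sequence for the local head chunk
    return torch.cat(recv, dim=1)  # [b, s, n/p, h]
```

The inverse all-to-all (scatter the sequence, gather the heads) is applied to the attention output so each rank returns to [b, s/p, n, h].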
@zeyugao Have you compared the loss with and without sequence parallelism? After adding a fixed seed to DistributedSampler, the training data is the same. I modified trainer.py as follows:
```python
if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
    assert self.args.group_by_length is False, "Group by length is not supported with sequence parallel."
    return DistributedSampler(
        dataset=self.train_dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
        shuffle=True,
        seed=42,
    )
```
However, with the same data, the average loss with sequence parallelism differs from the loss without it.
In addition, why does starcoder not support SDPA? I am trying to modify qwen2 and do not know whether SDPA is supported there.
@glowwormX The main reason should be that it needs a custom loss calculation; otherwise some tokens (at the head and tail of each sub-sequence) do not contribute to the final loss: https://github.com/microsoft/DeepSpeed/pull/5774/files#diff-13f25bb51b0f4019d8cb09c07204a33510dca5dccfae736baf10134f893704d5
> the reason why starcoder does not support sdpa

I did not have enough spare time at that point to get the shapes correct when using SDPA for starcoder2.
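To illustrate the boundary-token point above, here is a hedged sketch of one possible sequence-parallel-aware loss (it is neither this PR's code nor the DeepSpeed loss linked above): it assumes the labels were already shifted by one in the data collator before the sub-sequence split, so every local token still has its target at sub-sequence boundaries, and it divides by the token count of the whole sequence-parallel group; sequence_parallel_loss and sp_group are assumed names.

```python
# A hedged sketch, not the PR's implementation. Assumes labels[:, t] already
# holds the token at position t + 1 (shifted in the collator before splitting).
import torch
import torch.distributed as dist
import torch.nn.functional as F

def sequence_parallel_loss(logits, labels, sp_group, ignore_index=-100):
    # logits: [b, s/p, vocab], labels: [b, s/p]
    loss_sum = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(),
        ignore_index=ignore_index, reduction="sum",
    )
    # count label tokens across the whole sequence-parallel group
    num_tokens = (labels != ignore_index).sum()
    dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM, group=sp_group)
    # each rank backpropagates only its partial sum; dividing by the global
    # token count makes the gradients, once summed over the group, match the
    # gradient of the mean loss over the full sequence
    return loss_sum / num_tokens
```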
@zeyugao: Your implementation does not use this loss function, right? Does it still work OK even so?
cc @XuehaiPan