
Support deepspeed sequence parallel

Open: zeyugao opened this pull request 1 year ago, 10 comments

What does this PR do?

This PR adds support for sequence parallelism with DeepSpeed-Ulysses.

I have tested training on starcoder2-3b, and the loss decreases normally.

[Screenshot: starcoder2-3b training loss, 2024-06-21]

Requires https://github.com/huggingface/accelerate/pull/2877

~~I have made substantial modifications to the original implementation of Deepspeed-Ulysses in layers.py to support the batch-size dimension. It uses all_to_all_single instead of all_to_all, like https://github.com/InternLM/InternEvo/blob/a61d391df96c5f5c243cdea32a5044b70d6fe33e/internlm/core/parallel/comm/isp.py#L628, for better performance. I have left some comments to help future readers.~~ Using all_to_all_single turned out to be too complex to support other scatter_idx and gather_idx values.

Currently, flash attention and SDPA are tested for Llama and Mistral. Flash attention is also tested for StarCoder, but SDPA is not supported for StarCoder.

It requires a special dataloader (which I have implemented in the Trainer) and a special data collator (example below). In the data collator, each sequence must be split into sub-sequences, one per sequence-parallel rank. The following example shows the sub-sequence processing in the data collator:

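            # Fragment from a data collator's __call__. Assumes `torch` is imported,
            # `mpu` is the sequence-parallel group utility used by this PR, and
            # `input_ids`/`labels` are padded tensors of shape [batch, seq_length].
            seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
            seq_parallel_world_rank = mpu.get_sequence_parallel_rank()

            seq_length = input_ids.size(1)
            # Every rank needs an equal-length chunk; otherwise the all-to-all gathers
            # fewer tokens than the full-length attention_mask below expects.
            assert seq_length % seq_parallel_world_size == 0, (
                "Pad sequences to a multiple of the sequence parallel world size"
            )
            sub_seq_length = seq_length // seq_parallel_world_size
            sub_seq_start = seq_parallel_world_rank * sub_seq_length
            sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length

            # There is no KV cache during training
            past_key_values_length = 0

            # Positions cover the full sequence so each chunk keeps its global positions
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
            ).unsqueeze(0)

            batch = dict(
                input_ids=input_ids[:, sub_seq_start:sub_seq_end],
                labels=labels[:, sub_seq_start:sub_seq_end],
                position_ids=position_ids[:, sub_seq_start:sub_seq_end],
                # Not sliced: after the all-to-all, each rank attends over the full
                # sequence (for a subset of heads)
                attention_mask=(input_ids != self.tokenizer.pad_token_id),
            )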
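To try the split outside the Trainer, here is a standalone sketch of the same slicing, with the sequence-parallel rank and world size passed in explicitly instead of read from `mpu`. `split_for_sequence_parallel` and `pad_to_multiple` are hypothetical helpers, not part of this PR, and padding labels with -100 assumes the usual ignore index of the loss.

```python
import torch
import torch.nn.functional as F


def split_for_sequence_parallel(input_ids, labels, pad_token_id, sp_rank, sp_world_size):
    """Hypothetical helper: return the slice of a padded batch owned by one SP rank.

    input_ids, labels: LongTensor [batch, seq_len]; seq_len must be divisible by
    sp_world_size (pad beforehand, e.g. with pad_to_multiple below).
    """
    seq_length = input_ids.size(1)
    if seq_length % sp_world_size != 0:
        raise ValueError(f"seq_length={seq_length} is not divisible by {sp_world_size}")
    sub_seq_length = seq_length // sp_world_size
    start, end = sp_rank * sub_seq_length, (sp_rank + 1) * sub_seq_length

    # Global positions, so each chunk keeps its place in the full sequence.
    position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)

    return dict(
        input_ids=input_ids[:, start:end],
        labels=labels[:, start:end],
        position_ids=position_ids[:, start:end],
        # Full-length mask, as in the collator example above.
        attention_mask=(input_ids != pad_token_id),
    )


def pad_to_multiple(input_ids, labels, multiple, pad_token_id, label_pad=-100):
    """Hypothetical helper: right-pad so seq_len is a multiple of `multiple`."""
    remainder = (-input_ids.size(1)) % multiple
    if remainder:
        input_ids = F.pad(input_ids, (0, remainder), value=pad_token_id)
        labels = F.pad(labels, (0, remainder), value=label_pad)
    return input_ids, labels


# Usage: a length-7 sequence padded to 8 and split across 4 SP ranks.
ids = torch.tensor([[5, 6, 7, 8, 9, 0, 0]])
ids, lbl = pad_to_multiple(ids, ids.clone(), multiple=4, pad_token_id=0)
parts = [split_for_sequence_parallel(ids, lbl, 0, r, 4) for r in range(4)]
```

Padding to a multiple of the sequence-parallel world size up front avoids the divisibility failure instead of silently dropping trailing tokens.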

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@muellerzr and @SunMarc

zeyugao, Jun 20 '24

Great! Can you provide an example of data processing for sequence parallelism? Thanks.

fan-niu, Jun 27 '24

The dataset and sampler are handled in the Trainer:

https://github.com/huggingface/transformers/pull/31525/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR847-R855

The data collator example was accidentally deleted while editing the PR description. Here it is:

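            # Fragment from a data collator's __call__. Assumes `torch` is imported,
            # `mpu` is the sequence-parallel group utility used by this PR, and
            # `input_ids`/`labels` are padded tensors of shape [batch, seq_length].
            seq_parallel_world_size = mpu.get_sequence_parallel_world_size()
            seq_parallel_world_rank = mpu.get_sequence_parallel_rank()

            seq_length = input_ids.size(1)
            # Every rank needs an equal-length chunk; otherwise the all-to-all gathers
            # fewer tokens than the full-length attention_mask below expects.
            assert seq_length % seq_parallel_world_size == 0, (
                "Pad sequences to a multiple of the sequence parallel world size"
            )
            sub_seq_length = seq_length // seq_parallel_world_size
            sub_seq_start = seq_parallel_world_rank * sub_seq_length
            sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_length

            # There is no KV cache during training
            past_key_values_length = 0

            # Positions cover the full sequence so each chunk keeps its global positions
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long,
            ).unsqueeze(0)

            batch = dict(
                input_ids=input_ids[:, sub_seq_start:sub_seq_end],
                labels=labels[:, sub_seq_start:sub_seq_end],
                position_ids=position_ids[:, sub_seq_start:sub_seq_end],
                # Not sliced: after the all-to-all, each rank attends over the full
                # sequence (for a subset of heads)
                attention_mask=(input_ids != self.tokenizer.pad_token_id),
            )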

zeyugao, Jun 27 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Aug 10 '24

How long until this PR is merged? When will it be finished? ...

ldh127, Sep 22 '24

cc @SunMarc if you have the bandwidth to take a look!

LysandreJik, Sep 27 '24

@zeyugao I carefully read your pull requests for transformers and accelerate, and pulled your code to try training. Now I have encountered a problem: when entering DistributedAttention, the q, k, v before _SeqAllToAll.apply are not [b, s/p, n, h], but still [b, s, n, h]. I checked the modified parts of the data processing, such as accelerate/data_loader.py and transformers/trainer.py, but did not find any relevant processing code. So, may I ask where the sequence splitting is done?
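For reference, the shape change that DistributedAttention relies on: each rank must enter with only its own slice, [b, s/p, n, h] (which is why the collator has to do the split), and the all-to-all then returns [b, s, n/p, h], the full sequence for 1/p of the heads. Below is a single-process simulation, assuming batch-first q/k/v and a head count divisible by p; `simulate_seq_all_to_all` is a hypothetical name, and the real exchange is one collective inside DeepSpeed's `_SeqAllToAll`, not this loop.

```python
import torch


def simulate_seq_all_to_all(chunks):
    """Single-process simulation of the Ulysses all-to-all (scatter heads, gather sequence).

    chunks: list of p tensors; chunk r is rank r's local q/k/v of shape [b, s/p, n, h].
    Returns p tensors; output r has shape [b, s, n/p, h], the full sequence for the
    r-th group of heads.
    """
    p = len(chunks)
    n = chunks[0].size(2)
    assert n % p == 0, "number of heads must be divisible by the SP world size"
    # split_heads[src][dst] is the head group that rank `src` sends to rank `dst`
    split_heads = [c.chunk(p, dim=2) for c in chunks]
    # Rank `dst` concatenates what every rank sent it along the sequence dimension
    return [torch.cat([split_heads[src][dst] for src in range(p)], dim=1) for dst in range(p)]


torch.manual_seed(0)
b, s, n, h, p = 2, 8, 4, 3, 2
full = torch.randn(b, s, n, h)
local = list(full.chunk(p, dim=1))  # what each rank holds after the collator split
out = simulate_seq_all_to_all(local)
```

If the collator passes the full sequence instead, each rank starts from [b, s, n, h], which matches the symptom described above.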

glowwormX, Oct 08 '24

@glowwormX It is in the PR description.


zeyugao, Oct 08 '24

@zeyugao Oh, I missed it; I thought this code was part of the PR's changes. Thank you for replying.

glowwormX, Oct 08 '24

@zeyugao Have you compared the loss with and without sequence parallelism? After adding a fixed seed to the DistributedSampler, the training data is the same in both runs. I modified trainer.py as follows:

        if is_accelerate_available() and mpu.sequence_parallel_is_enabled():
            assert self.args.group_by_length is False, "Group by length is not supported with sequence parallel."
            return DistributedSampler(
                dataset=self.train_dataset,
                num_replicas=mpu.get_data_parallel_world_size(),
                rank=mpu.get_data_parallel_rank(),
                shuffle=True,
                seed=42
            )

However, on the same data, the average loss with sequence parallelism differs from the loss without it.

Also, why does StarCoder not support SDPA? I am trying to adapt Qwen2 and do not know whether SDPA will work for it.
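The sampler in the snippet above is sharded over data-parallel ranks rather than over all ranks, so every rank in one sequence-parallel group reads the same samples and the collator then picks that rank's slice. A small sketch of that effect, assuming each sequence-parallel group is a block of consecutive global ranks (a hypothetical layout; the real mapping comes from the PR's `mpu` utilities) and using a hypothetical helper `sampler_indices`:

```python
from torch.utils.data import DistributedSampler


def sampler_indices(dataset, global_rank, world_size, sp_world_size, seed=42, epoch=0):
    """Indices one rank would read when the sampler is sharded over DP ranks.

    Assumption (hypothetical layout): SP groups are consecutive global ranks, so
    the data-parallel rank is global_rank // sp_world_size.
    """
    dp_world_size = world_size // sp_world_size
    dp_rank = global_rank // sp_world_size
    sampler = DistributedSampler(
        dataset, num_replicas=dp_world_size, rank=dp_rank, shuffle=True, seed=seed
    )
    sampler.set_epoch(epoch)  # same seed + epoch => same permutation on every rank
    return list(sampler)


dataset = list(range(10))
world_size, sp_world_size = 4, 2
shards = {r: sampler_indices(dataset, r, world_size, sp_world_size) for r in range(world_size)}
```

Ranks 0 and 1 (one SP group) get identical indices, as do ranks 2 and 3, while the two data-parallel shards together cover the dataset once.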

glowwormX, Oct 17 '24

@glowwormX The main reason is probably that a custom loss calculation is needed; otherwise some tokens (at the head and tail of each sub-sequence) do not contribute to the final loss: https://github.com/microsoft/DeepSpeed/pull/5774/files#diff-13f25bb51b0f4019d8cb09c07204a33510dca5dccfae736baf10134f893704d5
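To make the boundary problem concrete: if each rank shifts labels inside its own chunk (as the built-in causal-LM loss in these models does on whatever labels it receives), the first label of every chunk after the first is never a target, so p - 1 targets per row drop out. The sketch below is a single-process simulation with hypothetical function names, not the code from the linked DeepSpeed PR. It compares that naive loss with shifting labels over the full sequence before splitting and then reducing the summed loss and token count across chunks (in a real run, an `all_reduce` over the sequence-parallel group).

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # F.cross_entropy's default ignore_index


def full_sequence_loss(logits, labels):
    """Reference: standard causal-LM loss with the shift over the whole sequence."""
    shift_logits, shift_labels = logits[:, :-1, :], labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1), ignore_index=IGNORE_INDEX)


def naive_chunked_loss(logits, labels, world_size):
    """Each 'rank' shifts inside its own chunk; per-rank means are averaged."""
    losses, n_targets = [], 0
    for lg, lb in zip(logits.chunk(world_size, dim=1), labels.chunk(world_size, dim=1)):
        shift_logits, shift_labels = lg[:, :-1, :], lb[:, 1:]
        n_targets += shift_labels.numel()
        losses.append(F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                                      shift_labels.reshape(-1)))
    return torch.stack(losses).mean(), n_targets


def sp_aware_loss(logits, labels, world_size):
    """Shift labels over the full sequence before splitting, then reduce sum and count."""
    shifted = torch.full_like(labels, IGNORE_INDEX)
    shifted[:, :-1] = labels[:, 1:]
    total, count = logits.new_zeros(()), 0
    for lg, lb in zip(logits.chunk(world_size, dim=1), shifted.chunk(world_size, dim=1)):
        total = total + F.cross_entropy(lg.reshape(-1, lg.size(-1)), lb.reshape(-1),
                                        ignore_index=IGNORE_INDEX, reduction="sum")
        count += (lb != IGNORE_INDEX).sum().item()
    return total / count


torch.manual_seed(0)
batch, seq_len, vocab, world_size = 2, 16, 11, 4
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
ref = full_sequence_loss(logits, labels)
naive, naive_targets = naive_chunked_loss(logits, labels, world_size)  # 24 targets, not 30
fixed = sp_aware_loss(logits, labels, world_size)  # matches ref
```

Summing losses and counting tokens before dividing also keeps the result correct when chunks contain different numbers of padded (ignored) labels, which a plain average of per-rank means does not.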

> the reason why starcoder does not support sdpa

At the time, I did not have enough spare time to get the tensor shapes right when using SDPA for starcoder2.

zeyugao, Oct 19 '24

@zeyugao Your implementation does not use this loss function, right? Does it still work well without it?

ronald-d-rogers, Nov 26 '24

cc @XuehaiPan

SunMarc, Dec 30 '24