Add masking of different samples in a long sequence for the flash-attention mechanism

sz128 opened this pull request

What does this PR do?

Fixes # (issue)

In LLM training, we often pack several short samples into one long sequence for efficient training. In this situation, it is ideal to mask attention between the different samples.

For a standard causal self-attention implementation, we can use a 3-D mask matrix to separate the different samples. However, flash-attention does not support a 3-D causal mask, so we need a workaround.

attention_mask_in_length is used to mask out the other short samples packed in the same sequence. The motivation for this function is explained here.

  1. We can utilize attention_mask_in_length to get indices and lengths of all short samples.
  2. Next, long sequence embeddings and length indicators are fed into the Flash attention mechanism to obtain its outputs.
  3. Finally, through the use of an inverse operation, we can rearrange the outputs to match the shape of the original batch.
An example of attention_mask_in_length

For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:

        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]

which corresponds to the 3-D attention mask:

        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
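
As a rough sketch of how such a length mask can be consumed (this is not the exact code of this PR; the helper name and example values are illustrative), the per-sample lengths can be converted into the cumulative sequence lengths that flash-attention's varlen interface expects:

import torch

# Hypothetical helper: turn attention_mask_in_length into the cu_seqlens / max_seqlen
# arguments expected by flash_attn.flash_attn_varlen_func.
def lengths_to_cu_seqlens(attention_mask_in_length):
    # attention_mask_in_length: (batch, seqlen) int tensor; each row holds the lengths
    # of the packed samples followed by zeros.
    lengths = attention_mask_in_length.flatten()
    lengths = lengths[lengths > 0]  # drop the zero padding
    cu_seqlens = torch.zeros(lengths.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens, int(lengths.max())

mask_in_length = torch.tensor([[2, 3, 0, 0, 0, 0],
                               [3, 2, 0, 0, 0, 0],
                               [6, 0, 0, 0, 0, 0]])
cu_seqlens, max_seqlen = lengths_to_cu_seqlens(mask_in_length)
print(cu_seqlens, max_seqlen)  # tensor([ 0,  2,  5,  8, 10, 16], dtype=torch.int32) 6
# These would be passed to flash_attn_varlen_func together with q/k/v flattened to
# (total_tokens, n_heads, head_dim) and causal=True, and the outputs re-padded to
# the original (batch, seqlen) shape afterwards.
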
Use sample masking of flash-attention for Llama3

https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb

Here is the core code to build attention_mask_in_length in a Megatron-LM-style data collator (https://github.com/microsoft/Megatron-DeepSpeed/blob/b7b2d5ef330f43729b406630e6c5d38e873d7398/megatron/utils.py#L162):

import torch


def mask_concat_samples(batch_data, eos_token_id, reset_position_ids=False):
    input_ids = batch_data["input_ids"]
    labels = batch_data["labels"].clone()
    micro_batch_size, seq_length = input_ids.shape

    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    if reset_position_ids:
        # expand() returns a read-only view; clone before modifying in place.
        position_ids = position_ids.clone()

    # Each row stores the lengths of the packed samples, followed by zeros.
    inner_sample_lengths = torch.zeros((micro_batch_size, seq_length), dtype=torch.int)
    for b in range(micro_batch_size):
        # Find the indices of the EOS tokens that separate the packed samples.
        eod_index = position_ids[b, input_ids[b] == eos_token_id]
        # Detach the indices from the positions if we are going to modify the positions.
        if reset_position_ids:
            eod_index = eod_index.clone()

        prev_index = -1
        for j in range(len(eod_index)):
            inner_sample_lengths[b, j] = eod_index[j] - prev_index
            prev_index = eod_index[j]
            if eod_index[j] < seq_length - 1:
                # Do not compute loss on the first token of the next sample.
                labels[b, eod_index[j] + 1] = -100

        # Account for a trailing sample that does not end with an EOS token.
        if prev_index < seq_length - 1:
            inner_sample_lengths[b, len(eod_index)] = seq_length - 1 - prev_index

        assert len(input_ids[b]) == sum(inner_sample_lengths[b]).item()

        if reset_position_ids and len(eod_index) > 0:
            # Restart position ids at 0 after every EOS token, following the
            # referenced Megatron-LM implementation.
            prev_index = 0
            for j in range(len(eod_index)):
                i = eod_index[j].item()
                position_ids[b, (i + 1):] -= i + 1 - prev_index
                prev_index = i + 1

    batch_data["labels"] = labels
    batch_data["attention_mask"] = inner_sample_lengths

    if reset_position_ids:
        batch_data["position_ids"] = position_ids

    return batch_data
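
A toy usage example (the values and eos_token_id = 2 are illustrative):

batch = {
    "input_ids": torch.tensor([[5, 6, 2, 7, 8, 9, 2, 4],
                               [3, 3, 3, 2, 5, 5, 5, 2]]),
    "labels": torch.tensor([[5, 6, 2, 7, 8, 9, 2, 4],
                            [3, 3, 3, 2, 5, 5, 5, 2]]),
}
mask_concat_samples(batch, eos_token_id=2)
print(batch["attention_mask"])
# tensor([[3, 4, 1, 0, 0, 0, 0, 0],
#         [4, 4, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)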

We could also adapt DataCollatorWithFlattening to produce this attention_mask format: https://github.com/huggingface/transformers/blob/23d2c69a527e5d1868999c5693b7e108e29563aa/src/transformers/data/data_collator.py#L1617-L1663

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

sz128 avatar Aug 19 '24 08:08 sz128

cc @ArthurZucker

amyeroberts avatar Aug 19 '24 08:08 amyeroberts

Isn't this already supported with https://github.com/huggingface/transformers/pull/31629 ?

eldarkurtic avatar Aug 19 '24 14:08 eldarkurtic

Isn't this already supported with #31629 ?

The two implementations are indeed similar. But we need to consider the situation where position IDs are not reset between different samples (especially for LLM pre-training).

sz128 avatar Aug 21 '24 03:08 sz128

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and https://github.com/huggingface/transformers/pull/31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.

While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

sz128 avatar Aug 21 '24 12:08 sz128

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and #31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.

While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

Hello! What does "position IDs are not reset between different short samples" mean exactly, and why is this especially relevant in pre-training? Very helpful work, thanks for the answers.

beep-bebop avatar Aug 26 '24 02:08 beep-bebop

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and #31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix. While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

Hello! What does "position IDs are not reset between different short samples" mean exactly, and why is this especially relevant in pre-training? Very helpful work, thanks for the answers.

Ohh, I mean it is useful in long-context pre-training.

For long-context training (for example, a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. In this case, position IDs should not be reset.

sz128 avatar Aug 26 '24 03:08 sz128

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we would compute the position ids ourselves? That is something we'd rather avoid in general, as is forcing the user to pass position ids. Since passing the correct position ids is already supported, IMO we should not add this!

ArthurZucker avatar Aug 28 '24 09:08 ArthurZucker

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we would compute the position ids ourselves? That is something we'd rather avoid in general, as is forcing the user to pass position ids. Since passing the correct position ids is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

sz128 avatar Aug 28 '24 11:08 sz128

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we would compute the position ids ourselves? That is something we'd rather avoid in general, as is forcing the user to pass position ids. Since passing the correct position ids is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

A less native approach might be to tokenise the data and then splice it piecewise via the map method, then use that dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, as the total length of these 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃‍♂️
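
A rough sketch of this map-based packing (assuming an already-tokenized Hugging Face datasets dataset with an input_ids column; the names are illustrative):

def pack_to_length(examples, block_size=128 * 1024):
    # Concatenate all tokenized samples in the map batch, then cut into fixed-size blocks.
    concatenated = [tok for ids in examples["input_ids"] for tok in ids]
    n_blocks = len(concatenated) // block_size
    return {"input_ids": [concatenated[i * block_size:(i + 1) * block_size]
                          for i in range(n_blocks)]}

# packed = tokenized_dataset.map(pack_to_length, batched=True, batch_size=1000,
#                                remove_columns=tokenized_dataset.column_names)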

beep-bebop avatar Sep 02 '24 07:09 beep-bebop

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we would compute the position ids ourselves? That is something we'd rather avoid in general, as is forcing the user to pass position ids. Since passing the correct position ids is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

A less native approach might be to tokenise the data and then splice it piecewise via the map method, then use that dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, as the total length of these 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃‍♂️

You can take a look at https://github.com/huggingface/transformers/issues/14767.

sz128 avatar Sep 03 '24 12:09 sz128

Why wouldn't we use position_ids to encode all information (packed, not packed, padded, not padded) in a slightly more elegant way without touching attention_mask?

For example let's say that the sequence length is 8:

  1. for perfectly packed sequences (e.g. two sequences of length 4): attention_mask = [1, 1, 1, 1, 1, 1, 1, 1] and position_ids = [0, 1, 2, 3, 0, 1, 2, 3]
  2. for partially packed sequences (e.g. two sequences of length 3 and the rest are padding tokens): attention_mask = [1, 1, 1, 1, 1, 1, 0, 0] and position_ids = [0, 1, 2, 0, 1, 2, 0, 1]

When the information about sequences is stored in position_ids (rather than in attention_mask), the positional embeddings are automatically computed correctly, so it is a minimally invasive approach relative to the existing transformers codebase.

eldarkurtic avatar Sep 03 '24 14:09 eldarkurtic

Why wouldn't we use position_ids to encode all information (packed, not packed, padded, not padded) in a slightly more elegant way without touching attention_mask?

For example let's say that the sequence length is 8:

  1. for perfectly packed sequences (e.g. two sequences of length 4): attention_mask = [1, 1, 1, 1, 1, 1, 1, 1] and position_ids = [0, 1, 2, 3, 0, 1, 2, 3]
  2. for partially packed sequences (e.g. two sequences of length 3 and the rest are padding tokens): attention_mask = [1, 1, 1, 1, 1, 1, 0, 0] and position_ids = [0, 1, 2, 0, 1, 2, 0, 1]

When the information about sequences is stored in position_ids (rather than in attention_mask), the positional embeddings are automatically computed correctly, so it is a minimally invasive approach relative to the existing transformers codebase.

For long-context training (e.g., 128k long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. Thus, we don't want to reset position_ids in this case.

For instance, consider four sequences of length 32k. The attention_mask would be [32k, 32k, 32k, 32k, 0, ..., 0, 0, 0], and the position_ids would be [0, 1, 2, 3, 4, 5, 6, 7, ..., 128k-1]. This allows the model to learn position embeddings for longer sequences.

In contrast, if we use attention_mask = [1, 1, 1, ..., 1, 1] and position_ids = [0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1], the model can only learn position embeddings in the range of [0, 32k-1].
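
To make the contrast concrete, here is a toy illustration (a sample length of 4 and a total length of 16 standing in for 32k and 128k):

import torch

sample_len, n_samples = 4, 4
total = sample_len * n_samples

# Non-reset position ids: the model sees positions up to total - 1 (long-context signal).
position_ids_long = torch.arange(total)                          # [0, 1, ..., 15]
# Reset position ids: the model only ever sees positions 0 .. sample_len - 1.
position_ids_reset = torch.arange(sample_len).repeat(n_samples)  # [0, 1, 2, 3, 0, 1, 2, 3, ...]

# Length-style attention mask as proposed in this PR (zero-padded to the sequence length).
attention_mask_in_length = torch.zeros(1, total, dtype=torch.int)
attention_mask_in_length[0, :n_samples] = sample_len             # [[4, 4, 4, 4, 0, 0, ...]]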

sz128 avatar Sep 04 '24 03:09 sz128

Yep, agree with that definitely! My proposal was to leave this choice to users to set in the data collator. If they wish to treat such concatenated sequences as a single sequence, they would set position_ids = [0, 1, 2, ..., 128k - 1]. If they wish to treat them as multiple shorter sequences concatenated together, they would set position_ids = [0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1].

eldarkurtic avatar Sep 04 '24 06:09 eldarkurtic

Yep, agree with that definitely! My proposal was to leave this choice to users to set in the data collator. If they wish to treat such concatenated sequences as a single sequence, they would set position_ids = [0, 1, 2, ..., 128k - 1]. If they wish to treat them as multiple shorter sequences concatenated together, they would set position_ids = [0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1].

Well, the point is that when we treat concatenated sequences as a single sequence (position_ids = [0, 1, 2, ..., 128k - 1]), we still need an attention mask to prevent self-attention between different samples within the same sequence. This attention mask is important for very long context training, which is discussed in the LLaMA-3 paper (https://arxiv.org/pdf/2407.21783, the bottom of page 6).

sz128 avatar Sep 04 '24 07:09 sz128

Are you suggesting that they are using position_ids = [0, 1, 2, ..., len(concatenated_sequences) - 1] with flash_attn_varlen_func to prevent cross-document attention? If yes, I feel this is doing mix-and-match: for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

We use an attention mask that prevents self-attention between different documents within the same sequence.

^I have interpreted this sentence in a different way. They concatenate sequences together, and make sure that there is no cross-document attention, which would translate into: position_id = [0, 1, 2, ..., len(seq1) - 1, 0, 1, 2, ..., len(seq2) - 1, ...] which we would use to correctly compute positional embeddings and to figure out where each sequence starts and ends (cu_seq_lens), and use that to call flash_attn_varlen_func https://github.com/huggingface/transformers/blob/ecd61c62862f925a18b4f063dc17fcaf01826e25/src/transformers/modeling_flash_attention_utils.py#L270-L293
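
For reference, a minimal sketch of that idea, recovering the sequence boundaries from reset position_ids (the function name is illustrative; this is not the actual transformers code):

import torch

def cu_seqlens_from_position_ids(position_ids):
    # position_ids: (1, total_tokens), e.g. [[0, 1, 2, 0, 1, 2, 0, 1]]
    flat = position_ids.flatten()
    # A new sample starts wherever the position id drops back to 0.
    starts = torch.nonzero(flat == 0, as_tuple=True)[0]
    cu_seqlens = torch.cat([starts, torch.tensor([flat.numel()])]).to(torch.int32)
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen

print(cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1]])))
# (tensor([0, 3, 6, 8], dtype=torch.int32), 3)

Note that trailing padding with positions [0, 1, ...] simply shows up as one more short segment.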

eldarkurtic avatar Sep 04 '24 09:09 eldarkurtic

for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

Yes. We want to learn position embeddings for the larger position IDs (32k to 128k-1).

sz128 avatar Sep 05 '24 03:09 sz128

for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

Yes. We want to learn position embeddings for the larger position IDs (32k to 128k-1).

Are there any ablation studies which show that this mix-and-match approach helps?

eldarkurtic avatar Sep 05 '24 04:09 eldarkurtic

Since most LLMs use RoPE position embeddings, which are relative, there is no difference between resetting position IDs or not for cross-document attention masking.

sz128 avatar Sep 10 '24 06:09 sz128