Add masking of different samples in a long sequence for flash-attention mechanism
What does this PR do?
Fixes # (issue)
In LLM training, we often pack several short samples into one long sequence for efficient training. In this situation, it is desirable to mask attention between the different samples.
For a causal self-attention implementation, we can use a 3-D mask matrix to separate the different samples. But flash-attention does not support a causal 3-D mask matrix, so we need a shortcut.
The `attention_mask_in_length` tensor is utilized to mask the other short samples packed into the same sequence. The motivation for this function is explained here.
- We can utilize `attention_mask_in_length` to get the indices and lengths of all short samples.
- Next, the long-sequence embeddings and the length indicators are fed into the flash-attention mechanism to obtain its outputs.
- Finally, through an inverse operation, we rearrange the outputs to match the shape of the original batch (a sketch of the gather/scatter steps follows this list).
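To make these steps concrete, here is a minimal, torch-only sketch of the gather and inverse-scatter steps. The variable names and the 0/1 validity mask below are illustrative only (they are not part of this PR), and the variable-length flash-attention call itself is elided.

```python
import torch

batch, seqlen, hidden = 3, 6, 8
hidden_states = torch.randn(batch, seqlen, hidden)
# 1 where a token belongs to some sample; tokens not covered by any entry of
# attention_mask_in_length are treated as padding here.
valid_mask = torch.tensor([
    [1, 1, 1, 1, 1, 0],   # lengths [2, 3] -> last token is padding
    [1, 1, 1, 1, 1, 0],   # lengths [3, 2]
    [1, 1, 1, 1, 1, 1],   # lengths [6]
], dtype=torch.bool)

# 1) indices of valid tokens in the flattened (batch * seqlen) layout
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
# 2) "unpad": gather the valid tokens into one long stream for the varlen kernel
unpadded = hidden_states.reshape(-1, hidden)[indices]
# ... the varlen flash-attention kernel would run on `unpadded` here ...
# 3) inverse operation: scatter the outputs back to the original (batch, seqlen) shape
output = torch.zeros(batch * seqlen, hidden)
output[indices] = unpadded
output = output.reshape(batch, seqlen, hidden)
```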
An example of attention_mask_in_length
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
[
[2, 3],
[3, 2],
[6, 0]
]
, which refers to the 3D-attention mask:
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
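Flash-attention's variable-length kernels do not consume this matrix directly; they take cumulative sequence lengths, so an implementation typically flattens the rows above into `cu_seqlens`. A minimal sketch of that conversion (the helper name is hypothetical, not part of this PR):

```python
import torch

def lengths_to_cu_seqlens(attention_mask_in_length):
    lengths = attention_mask_in_length.flatten()
    lengths = lengths[lengths > 0]   # zeros only pad the length matrix
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)
    )
    # Assumes tokens not covered by any length (i.e. padding) are removed by the unpad step.
    return cu_seqlens, int(lengths.max())

attention_mask_in_length = torch.tensor([[2, 3], [3, 2], [6, 0]])
cu_seqlens, max_seqlen = lengths_to_cu_seqlens(attention_mask_in_length)
print(cu_seqlens)   # tensor([ 0,  2,  5,  8, 10, 16], dtype=torch.int32)
print(max_seqlen)   # 6
```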
Use sample masking of flash-attention for Llama3
https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb
Here is the core code to build `attention_mask_in_length` in a Megatron-LM-style data collator (adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/b7b2d5ef330f43729b406630e6c5d38e873d7398/megatron/utils.py#L162):
```python
import torch


def mask_concat_samples(batch_data, eos_token_id, reset_position_ids=False):
    input_ids = batch_data["input_ids"]
    labels = batch_data["labels"].clone()
    micro_batch_size, seq_length = input_ids.shape
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    if reset_position_ids:
        # Clone so the expanded view can be modified in place without aliasing rows.
        position_ids = position_ids.clone()
    # Each row stores the lengths of the packed samples, padded with zeros.
    inner_sample_lengths = torch.zeros((micro_batch_size, seq_length), dtype=torch.int)
    for b in range(micro_batch_size):
        # Find the indices of the EOS tokens.
        eod_index = position_ids[b, input_ids[b] == eos_token_id]
        # Detach the indices from the positions if we are going to modify the positions.
        if reset_position_ids:
            eod_index = eod_index.clone()
        prev_index = -1
        for j in range(len(eod_index)):
            inner_sample_lengths[b, j] = eod_index[j] - prev_index
            prev_index = eod_index[j]
            if eod_index[j] < seq_length - 1:
                # Do not compute loss for predicting the first token of the next sample.
                labels[b, eod_index[j] + 1] = -100
        if prev_index < seq_length - 1:
            # The trailing tokens after the last EOS form one more (truncated) sample.
            inner_sample_lengths[b, len(eod_index)] = seq_length - 1 - prev_index
        assert len(input_ids[b]) == sum(inner_sample_lengths[b]).item()
        if reset_position_ids and len(eod_index) > 0:
            # Reset the position ids so that every packed sample starts from 0.
            prev_index = 0
            for j in range(len(eod_index)):
                i = eod_index[j]
                position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                prev_index = i + 1
    batch_data["labels"] = labels
    # Store the per-sample lengths in place of the usual 0/1 attention mask.
    batch_data["attention_mask"] = inner_sample_lengths
    if reset_position_ids:
        batch_data["position_ids"] = position_ids
```
We can also change the `attention_mask` in the same way for `DataCollatorWithFlattening`:
https://github.com/huggingface/transformers/blob/23d2c69a527e5d1868999c5693b7e108e29563aa/src/transformers/data/data_collator.py#L1617-L1663
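If one prefers to start from `DataCollatorWithFlattening`, which already emits position ids that restart at 0 for each packed example, the length-style `attention_mask` used in this PR can be recovered from those resets. A hedged sketch (the helper name is hypothetical and it assumes every sample's position ids start at 0):

```python
import torch

def position_ids_to_lengths(position_ids: torch.Tensor) -> torch.Tensor:
    batch, seqlen = position_ids.shape
    lengths = torch.zeros((batch, seqlen), dtype=torch.int)
    for b in range(batch):
        # every packed sample starts where its position id resets to 0
        starts = torch.nonzero(position_ids[b] == 0, as_tuple=False).flatten()
        ends = torch.cat([starts[1:], torch.tensor([seqlen])])
        lengths[b, : len(starts)] = (ends - starts).to(torch.int)
    return lengths

position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])   # samples of lengths 2, 3 and 1
print(position_ids_to_lengths(position_ids))        # [[2, 3, 1, 0, 0, 0]]
```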
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc @ArthurZucker
Isn't this already supported with https://github.com/huggingface/transformers/pull/31629 ?
It seems that both implementations are similar. But we need to consider the situation where position IDs are not reset between different samples (especially for LLM pre-training).
Hey! Don't you think that ragging the tensor would be more efficient?
Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and https://github.com/huggingface/transformers/pull/31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.
While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.
Hello, expert! What does 'position IDs are not reset between different short samples' mean specifically, and why does this come up especially in pre-training? Very helpful work, thanks for the answers.
Ohh, I mean it is useful in long-context pre-training.
For long-context training (for example, a 128k-long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. In this case, position IDs should not be reset.
we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training
Does this imply us properly computing the position ids? It's something we'd rather avoid in general, as is forcing the user to pass the position ids. As passing the correct position ids is already supported, IMO we should not add this!
No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.
A less native approach might be to tokenise the data and then splice it piecewise via the map method, and then use this dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, since the total length of those 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃♂️
You can take a look at https://github.com/huggingface/transformers/issues/14767.
Why wouldn't we use `position_ids` to encode all the information (packed, not packed, padded, not padded) in a slightly more elegant way, without touching `attention_mask`?
For example, let's say that the sequence length is 8:
- for perfectly packed sequences (e.g. two sequences of length 4): `attention_mask = [1, 1, 1, 1, 1, 1, 1, 1]` and `position_ids = [0, 1, 2, 3, 0, 1, 2, 3]`
- for partially packed sequences (e.g. two sequences of length 3 and the rest are padding tokens): `attention_mask = [1, 1, 1, 1, 1, 1, 0, 0]` and `position_ids = [0, 1, 2, 0, 1, 2, 0, 1]`
When the information about sequences is stored in `position_ids` (contrary to `attention_mask`), the positional embeddings are automatically calculated in a correct way, so it is a minimally invasive approach relative to the existing transformers codebase.
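For concreteness, here is a small sketch of how such `(attention_mask, position_ids)` pairs could be built from per-sample lengths (the helper below is illustrative, not an existing transformers API):

```python
import torch

def build_packed_inputs(sample_lengths, seqlen):
    position_ids = torch.zeros(seqlen, dtype=torch.long)
    attention_mask = torch.zeros(seqlen, dtype=torch.long)
    offset = 0
    for length in sample_lengths:
        position_ids[offset : offset + length] = torch.arange(length)
        attention_mask[offset : offset + length] = 1
        offset += length
    # remaining positions are padding: attention_mask stays 0 and positions restart from 0
    position_ids[offset:] = torch.arange(seqlen - offset)
    return attention_mask, position_ids

print(build_packed_inputs([4, 4], 8))   # mask [1]*8, positions [0, 1, 2, 3, 0, 1, 2, 3]
print(build_packed_inputs([3, 3], 8))   # mask [1, 1, 1, 1, 1, 1, 0, 0], positions [0, 1, 2, 0, 1, 2, 0, 1]
```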
For long-context training (e.g., 128k long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. Thus, we don't want to reset position_ids in this case.
For instance, consider four sequences of length 32k. The attention_mask would be [32k, 32k, 32k, 32k, 0, ..., 0, 0, 0], and the position_ids would be [0, 1, 2, 3, 4, 5, 6, 7, ..., 128k-1]. This allows the model to learn position embeddings for longer sequences.
In contrast, if we use attention_mask = [1, 1, 1, ..., 1, 1] and position_ids = [0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1], the model can only learn position embeddings in the range of [0, 32k-1].
Yep, agree with that definitely! My proposal was to leave this choice to users to set in the data collator.
If they wish to treat such concatenated sequences as a single sequence, they would set position_ids = [0, 1, 2, ..., 128k - 1].
If they wish to treat them as multiple shorter sequences concatenated together, they would set position_ids = [0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1].
Well, the point is that when we treat concatenated sequences as a single sequence (position_ids = [0, 1, 2, ..., 128k - 1]), we still need an attention mask to prevent self-attention between different samples within the same sequence. This attention mask is important for very long context training, which is discussed in the LLaMA-3 paper (https://arxiv.org/pdf/2407.21783, the bottom of page 6).
Are you suggesting that they are using position_ids = [0, 1, 2, ..., len(concatenated_sequences) - 1] with flash_attn_varlen_func to prevent cross-document attention? If yes, I feel this is doing mix-and-match: for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.
We use an attention mask that prevents self-attention between different documents within the same sequence.
^I have interpreted this sentence in a different way. They concatenate sequences together, and make sure that there is no cross-document attention, which would translate into:
position_id = [0, 1, 2, ..., len(seq1) - 1, 0, 1, 2, ..., len(seq2) - 1, ...]
which we would use to correctly compute positional embeddings and to figure out where each sequence starts and ends (cu_seq_lens), and use that to call flash_attn_varlen_func https://github.com/huggingface/transformers/blob/ecd61c62862f925a18b4f063dc17fcaf01826e25/src/transformers/modeling_flash_attention_utils.py#L270-L293
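A minimal sketch of that derivation, assuming every packed document's position ids restart at 0 (this is an illustration, not the actual code in modeling_flash_attention_utils.py):

```python
import torch

position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])   # three documents of lengths 3, 2 and 4
starts = torch.nonzero(position_ids == 0, as_tuple=False).flatten()
cu_seq_lens = torch.cat([starts, torch.tensor([position_ids.numel()])]).to(torch.int32)
max_seqlen = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
print(cu_seq_lens)   # tensor([0, 3, 5, 9], dtype=torch.int32)
print(max_seqlen)    # 4

# These would then be passed to the varlen kernel, roughly:
# flash_attn_varlen_func(q, k, v,
#     cu_seqlens_q=cu_seq_lens, cu_seqlens_k=cu_seq_lens,
#     max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=True)
```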
for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.
Yes. We want to learn position embeddings for larger position ids (32k~128k-1).
Are there any ablation studies which show that this mix-and-match approach helps?
Since most LLMs use RoPE position embeddings, which are relative, there is no difference between resetting position IDs or not for cross-document attention masking.