Allow passing 2D attention mask

Open UniverseFly opened this issue 9 months ago • 4 comments

Feature request

Allow passing a 2D attention mask in model.forward.

Motivation

With this feature, it would be much easier to avoid cross-context contamination during pretraining and supervised finetuning when packing the sequences together for more efficient training.

An example use case is discussed in https://github.com/huggingface/trl/issues/805.

Your contribution

After looking into the source code, I found that the logic for initializing attention masks is mostly a fixed code snippet duplicated in each model:

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

Enabling this behavior may require modifying each model individually. I should be able to handle some of them and submit a draft PR. Before that, I would like to know whether this feature request is reasonable.

UniverseFly avatar Nov 21 '23 18:11 UniverseFly

Hey, the model's forward already supports passing a 2D attention mask; it is just expanded to 4D because that is the format required by the attention implementation. Would you mind elaborating on what you cannot currently do? (Might be related to #27539?)
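
For readers following along, here is a minimal sketch of the kind of expansion described above: a [batch_size, num_tokens] padding mask is combined with a causal pattern and broadcast to [batch_size, 1, num_tokens, num_tokens]. It is not the library's `_prepare_4d_causal_attention_mask`; the function name `expand_padding_mask_to_4d` is hypothetical, and NumPy is used instead of torch only to keep the example dependency-light.

```python
import numpy as np

def expand_padding_mask_to_4d(attention_mask: np.ndarray) -> np.ndarray:
    """Expand a [batch, seq] 0/1 padding mask into a [batch, 1, seq, seq] additive causal mask.

    Hypothetical helper for illustration; not the transformers implementation.
    """
    batch_size, seq_len = attention_mask.shape
    # Lower-triangular matrix: query i may look at keys 0..i.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Broadcast the per-key padding flags over all query positions.
    key_is_real = attention_mask.astype(bool)[:, None, None, :]
    allowed = causal[None, None, :, :] & key_is_real
    # Additive form: 0 where attention is allowed, a large negative value elsewhere.
    min_value = np.finfo(np.float32).min
    return np.where(allowed, 0.0, min_value).astype(np.float32)

# One row of four tokens where the last token is padding.
expanded = expand_padding_mask_to_4d(np.array([[1, 1, 1, 0]]))
```

Because the input carries only one flag per key token, every query in a row sees the same set of visible keys (up to causality), which is why this shape cannot express per-example boundaries inside a packed row.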

ArthurZucker avatar Nov 22 '23 10:11 ArthurZucker

Hey, the model's forward already supports passing a 2D attention mask; it is just expanded to 4D because that is the format required by the attention implementation. Would you mind elaborating on what you cannot currently do? (Might be related to #27539?)

Yeah, I may not have been clear. The current "2D" masks have shape [batch_size, num_tokens]. What I am suggesting is [batch_size, num_tokens, num_tokens], so that each sequence in the batch has a matrix explicitly defining which tokens each token may attend to. https://github.com/huggingface/transformers/pull/27539 seems relevant.
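
To make the requested shape concrete, here is a minimal sketch of how such a [batch_size, num_tokens, num_tokens] mask could be built for packed sequences. It assumes each token carries the index of the example it came from; the `sequence_ids` input and the `packed_causal_mask` helper are hypothetical names for illustration, not part of transformers, and NumPy stands in for torch.

```python
import numpy as np

def packed_causal_mask(sequence_ids: np.ndarray) -> np.ndarray:
    """Build a [batch, num_tokens, num_tokens] boolean mask for packed sequences.

    sequence_ids[b, t] is the index of the original example that token t belongs to;
    mask[b, i, j] is True when query token i may attend to key token j.
    """
    num_tokens = sequence_ids.shape[1]
    # Tokens may only attend within their own packed example (block-diagonal).
    same_example = sequence_ids[:, :, None] == sequence_ids[:, None, :]
    # And only to the current or earlier positions (causal).
    causal = np.tril(np.ones((num_tokens, num_tokens), dtype=bool))
    return same_example & causal[None, :, :]

# Two examples of lengths 3 and 2 packed into one row of 5 tokens.
ids = np.array([[0, 0, 0, 1, 1]])
mask = packed_causal_mask(ids)
print(mask[0].astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]
```

The second example's tokens (rows 3 and 4) never see the first example's tokens, which is the cross-context isolation that a [batch_size, num_tokens] mask cannot encode.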

UniverseFly avatar Nov 22 '23 17:11 UniverseFly

Just chiming in with some more context (I am also very interested in this feature). From what I understand, this is not trivial to implement in general...

As one current example, the axolotl finetuning harness implements efficient sample packing with correct block diagonal attention masking through a series of monkey patches for the underlying huggingface model definitions for a few of the very popular models like llama and mistral. Though I have not looked through the code in detail, I believe it leverages the fact that the flash attention api supports the masking required to implement this scheme.
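
As a rough illustration of why a variable-length attention kernel can handle this without a dense mask: my understanding is that such interfaces describe a packed row by the cumulative boundaries of its examples rather than by a num_tokens x num_tokens matrix. The sketch below only computes those boundaries; it is not taken from axolotl, the helper name `cumulative_seq_lengths` is hypothetical, and no flash attention call is made.

```python
import numpy as np

def cumulative_seq_lengths(seq_lengths):
    """Return boundaries [0, l0, l0 + l1, ...] for examples packed back to back."""
    return np.concatenate([[0], np.cumsum(seq_lengths)]).astype(np.int32)

# Three examples of lengths 3, 2 and 4 packed into one row of 9 tokens.
cu_seqlens = cumulative_seq_lengths([3, 2, 4])
# Example k occupies tokens cu_seqlens[k] : cu_seqlens[k + 1] of the packed row.
```

Attention restricted to each of these slices is equivalent to the block-diagonal mask, but the memory needed to describe it grows with the number of examples rather than with the square of the row length.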

It is relevant for efficient finetuning (the reason it is incorporated into axolotl), and general wisdom (and whispers from inside large corps) suggests that this type of block diagonal masking is better for large scale training code.

(https://github.com/huggingface/transformers/pull/27539 is relevant, but it looks like the focus may be on the beam search/speculative decoding use case, not this slightly more general use case. Also here's a relevant hf forum post https://discuss.huggingface.co/t/the-correct-attention-mask-for-examples-packing/52909/2)

jwkirchenbauer avatar Dec 01 '23 02:12 jwkirchenbauer

Packing is indeed a good use case for supporting 2D attention masks in huggingface models.

meliksahturker avatar May 09 '24 21:05 meliksahturker