`MistralAttention`: where is the sliding window

fteufel opened this issue 11 months ago

Hi,

I'm trying to understand the implementation of Mistral's attention in MistralAttention (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L195). My understanding is that it should always use local (sliding-window) attention. In MistralFlashAttention2 this is obvious: config.sliding_window is used explicitly.

However, I'm not sure where the sliding window is used in the base MistralAttention without flash attention:

class MistralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

but the forward pass simply reads

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

which I read as full self-attention over the whole sequence.

Is the sliding window only used when running with Flash Attention, or am I missing something? Thanks!

fteufel avatar Mar 21 '24 12:03 fteufel

cc @ArthurZucker @younesbelkada

amyeroberts avatar Mar 21 '24 15:03 amyeroberts

I think the sliding window trick is based on masking? https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L998-L1018
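
Roughly something like this (a minimal sketch, not the exact code in modeling_mistral.py; the helper name is made up):

import torch

def sliding_window_causal_mask(seq_len, window, dtype=torch.float32):
    # Allow query i to attend to keys j with i - window < j <= i.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]            # j <= i
    local = idx[:, None] - idx[None, :] < window     # i - j < window
    bias = torch.zeros(seq_len, seq_len, dtype=dtype)
    bias.masked_fill_(~(causal & local), torch.finfo(dtype).min)
    return bias  # added to the (bsz, heads, q_len, kv_len) attn_weights

# Note the eager path still materializes the full seq_len x seq_len score matrix;
# the window only changes which entries survive the softmax, not the FLOPs.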

PenutChen avatar Mar 22 '24 01:03 PenutChen

Thanks, I see. But wouldn't this throw away any computational efficiency gains expected from using a sliding window in the first place?

fteufel avatar Mar 22 '24 09:03 fteufel

I have the same question. I think the sliding window has two aspects:

  1. From the perspective of the attention mask, it acts as a token-level sliding window that restricts each token's view of the context to the most recent sliding_window tokens.
  2. From the kv-cache perspective, truncating the cache to positions inside the window can improve computational efficiency and memory use.

Just my guess above; a rough sketch of the second aspect follows.
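
Hypothetical helper, not the actual transformers cache code:

import torch

def truncate_kv_cache(key_cache, value_cache, sliding_window):
    # key_cache / value_cache: (bsz, num_heads, kv_len, head_dim)
    # Positions older than the window can never be attended to again,
    # so dropping them bounds both memory and the size of the score matmul.
    if key_cache.shape[-2] > sliding_window:
        key_cache = key_cache[..., -sliding_window:, :]
        value_cache = value_cache[..., -sliding_window:, :]
    return key_cache, value_cache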

PenutChen avatar Mar 22 '24 09:03 PenutChen

Yes, the eager path does throw away those gains, and that is pretty much expected: the best way to use sliding_window is through the sdpa or the flash_attention API, unless a rotating buffer is used.
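
For example, with the flash-attn package the window is pushed into the kernel itself (illustrative call only, window_size needs flash-attn >= 2.3; this is not the exact call made inside MistralFlashAttention2):

import torch
from flash_attn import flash_attn_func

# flash-attn expects (bsz, seq_len, num_heads, head_dim) on CUDA in fp16/bf16.
q = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")

window = 256
# Causal + local window: each query sees at most the `window` previous keys,
# and the kernel skips the masked blocks instead of computing and discarding them.
out = flash_attn_func(q, k, v, causal=True, window_size=(window, 0))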

Closing as expected, feel free to discuss! 🤗

ArthurZucker avatar Mar 25 '24 09:03 ArthurZucker

Hi @ArthurZucker, interesting - so sdpa actually exploits the local-window structure of the attention mask in the backend?

fteufel avatar Mar 25 '24 09:03 fteufel

It should, if the mask is correctly passed, yeah. The new sdpa path has the sliding_window argument anyway. Not sure it was correctly prepared before; the important PR is #29407.

ArthurZucker avatar Mar 26 '24 13:03 ArthurZucker

It should, if the mask is correctly passed, yeah. The new sdpa path has the sliding_window argument anyway. Not sure it was correctly prepared before; the important PR is #29407.

@ArthurZucker Did you mean this PR: https://github.com/pytorch/pytorch/pull/114823? It does not use a sliding_window param explicitly, but it can handle the sliding-window mask inside the sdpa function, am I right? So if we pass the right mask through _prepare_4d_causal_attention_mask_for_sdpa, as you mentioned here, https://github.com/huggingface/transformers/pull/29407, we can use the local-window feature of Mistral. But I think we can still gain some computational efficiency from local attention even without the rotating buffer, because of the sparsity of the sliding-window attention mask.
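
Something like this with plain SDPA, assuming we build the boolean mask ourselves (a sketch; whether the backend actually exploits the sparsity depends on which kernel is selected, and with an explicit attn_mask the flash kernel is typically not used):

import torch
import torch.nn.functional as F

bsz, heads, seq_len, head_dim, window = 1, 8, 1024, 64, 256
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)
v = torch.randn(bsz, heads, seq_len, head_dim)

# Boolean mask, True = may attend: causal and within the sliding window.
idx = torch.arange(seq_len)
allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)

# attn_mask is broadcast to (bsz, heads, q_len, kv_len).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)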

ehuaa avatar Mar 27 '24 13:03 ehuaa

Sorry, I meant the new SDPA codepath in transformers (not merged yet); yes, as you say, it handles the mask.

ArthurZucker avatar Mar 27 '24 15:03 ArthurZucker