`MistralAttention`: where is the sliding window
Hi,

I'm trying to understand the implementation of Mistral's attention in `MistralAttention`:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L195

It is my understanding that it should always be using local window attention. In `MistralFlashAttention2` this is very obvious, with `config.sliding_window` being used. However, I'm not sure where the sliding window is used in the base `MistralAttention` without flash attention:
```python
class MistralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """
```

but the forward pass simply reads

```python
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
```
which I understand to be full self-attention.
Is the sliding window only used when running with Flash Attention, or am I missing something? Thanks!
cc @ArthurZucker @younesbelkada
I think the sliding window trick is based on masking? https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L998-L1018
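For anyone following along, here is a minimal sketch of what such a causal sliding-window mask could look like (illustrative only, not the exact transformers code; the helper name and window size are made up):

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int, dtype=torch.float32):
    # Hypothetical helper for illustration, not the transformers implementation.
    # Causal constraint: position i may only attend to positions j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Window constraint: position i may only attend to positions j > i - window.
    in_window = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-(window - 1))
    allowed = causal & in_window
    # Additive mask: 0 where attention is allowed, a very large negative value
    # elsewhere, so masked positions vanish after the softmax.
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return mask.masked_fill(allowed, 0.0)

# With seq_len=6 and window=3, token 5 can only attend to tokens 3, 4 and 5.
print(sliding_window_causal_mask(6, 3))
```

Note that the mask itself is still a full `seq_len x seq_len` tensor, which is why masking alone does not save any compute.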
Thanks, I see. But wouldn't this throw away any computational efficiency gains expected from using a sliding window in the first place?
I have the same question. I think the sliding window has two aspects:
- From the attention-mask perspective, it acts as a token-level sliding window that limits each token's view of the context.
- From a KV-cache perspective, truncating the cache to positions inside the window can improve computational efficiency (a rough sketch of this is below).

Just my guess above.
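For the second aspect, here is a rough sketch of what window-limited cache truncation could look like (purely illustrative; shapes and names are assumptions, not the transformers cache implementation):

```python
import torch

def append_to_windowed_cache(past_key, past_value, new_key, new_value, sliding_window: int):
    """Append new key/value states and keep only the last `sliding_window` positions.

    Tensors are assumed to have shape (batch, num_heads, seq_len, head_dim).
    With a causal sliding window, positions older than the window can never be
    attended to again, so their cache entries can be dropped; memory and the
    per-step attention cost then stay bounded by the window size.
    """
    key = torch.cat([past_key, new_key], dim=-2)
    value = torch.cat([past_value, new_value], dim=-2)
    if key.shape[-2] > sliding_window:
        key = key[..., -sliding_window:, :]
        value = value[..., -sliding_window:, :]
    return key, value
```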
Yes, this would throw away the gains, and that is pretty much expected: the best way to use `sliding_window` is through the `sdpa` or the `flash_attention` API, unless a rotating buffer is used.
Closing as expected, feel free to discuss! 🤗
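To make the `sdpa` route concrete, here is a hedged sketch of passing a sliding-window mask to `torch.nn.functional.scaled_dot_product_attention` (shapes and the window size are illustrative; this is not the transformers code path):

```python
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim, window = 1, 8, 1024, 64, 256
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

# Boolean mask: True where attention is allowed (causal AND inside the window).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
in_window = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-(window - 1))
attn_mask = causal & in_window

# SDPA dispatches to a fused kernel when it can; whether the selected backend
# actually skips the masked-out region (and so recovers the compute savings)
# depends on that backend, which is exactly the point being discussed here.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```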
Hi @ArthurZucker, interesting - so `sdpa` actually exploits the local window structure of the attention mask in the backend?
It should if the mask is correctly passed, yeah. The new `sdpa` path has the `sliding_window` argument anyway. Not sure it was correctly prepared before; important PR: #29407
> It should if the mask is correctly passed, yeah. The new `sdpa` path has the `sliding_window` argument anyway. Not sure it was correctly prepared before; important PR: #29407
@ArthurZucker Are you referring to this PR? https://github.com/pytorch/pytorch/pull/114823 - it doesn't use a `sliding_window` param explicitly, but it can handle the sliding-window mask inside the SDPA function, am I right? So if we pass the right mask through `_prepare_4d_causal_attention_mask_for_sdpa`, as you mentioned in https://github.com/huggingface/transformers/pull/29407, we can use the local-window feature of Mistral. But I think we can still gain some computational efficiency from local attention even without the rotating buffer, because of the sparsity of the sliding-window attention mask.
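To put a rough number on that (illustrative figures; the saving only materialises if the kernel actually skips the masked-out region):

```python
# Attention score entries per head: full attention vs. sliding-window attention.
seq_len = 32_768   # example prompt length
window = 4_096     # Mistral-7B's sliding window size
full = seq_len * seq_len       # ~1.07e9 score entries
windowed = seq_len * window    # ~1.34e8 score entries (upper bound)
print(f"full: {full:.2e}, windowed: {windowed:.2e}, ratio: {full / windowed:.0f}x")
# -> roughly an 8x reduction in score computation at this length, but only if the
#    attention kernel exploits the mask structure instead of materialising the
#    full score matrix.
```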
Sorry, I meant the new SDPA code path in transformers, but it's not merged yet. Yes, as you say, it handles the mask.